commit dcc1187e1d

@@ -43,12 +43,6 @@ import java.util.Random;
 
 public class CapsnetGradientCheckTest extends BaseDL4JTest {
 
-    private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
-
     @Test
     public void testCapsNet() {
@@ -43,10 +43,6 @@ import static org.junit.Assert.assertTrue;
 public class OutputLayerGradientChecks extends BaseDL4JTest {
 
     private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
 
     static {
         Nd4j.setDataType(DataType.DOUBLE);
@@ -47,10 +47,6 @@ import static org.junit.Assert.assertTrue;
 public class RnnGradientChecks extends BaseDL4JTest {
 
     private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
 
     static {
         Nd4j.setDataType(DataType.DOUBLE);
@@ -48,12 +48,6 @@ import static org.junit.Assert.assertTrue;
 
 public class UtilLayerGradientChecks extends BaseDL4JTest {
 
-    private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-6;
-
     static {
         Nd4j.setDataType(DataType.DOUBLE);
     }
@@ -182,9 +176,9 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
 
-        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
-                .minAbsoluteError(1e-7)
-                .labels(label).inputMask(inMask));
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net)
+                .minAbsoluteError(1e-6)
+                .input(input).labels(label).inputMask(inMask));
         assertTrue(gradOK);
 
         TestUtils.testModelSerialization(net);
@@ -233,8 +227,9 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
         //Test ComputationGraph equivalent:
         ComputationGraph g = net.toComputationGraph();
 
-        boolean gradOKCG = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(g).inputs(new INDArray[]{in})
-                .labels(new INDArray[]{labels}).excludeParams(excludeParams));
+        boolean gradOKCG = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(g)
+                .minAbsoluteError(1e-6)
+                .inputs(new INDArray[]{in}).labels(new INDArray[]{labels}).excludeParams(excludeParams));
         assertTrue(gradOKCG);
 
         TestUtils.testModelSerialization(g);
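For context: the constants these tests drop, and the `minAbsoluteError` values tuned above, parameterize a standard central-difference gradient check. A minimal Python sketch of that idea (illustrative only, not DL4J's `GradientCheckUtil` implementation):

```python
import numpy as np

def grad_check(f, grad, x, eps=1e-6, max_rel_error=1e-3, min_abs_error=1e-6):
    """Compare an analytic gradient against central finite differences."""
    g = grad(x)
    for i in range(x.size):
        x[i] += eps
        f_plus = f(x)
        x[i] -= 2 * eps
        f_minus = f(x)
        x[i] += eps  # restore the original value
        numeric = (f_plus - f_minus) / (2 * eps)
        diff = abs(numeric - g[i])
        scale = abs(numeric) + abs(g[i])
        # tiny absolute differences pass regardless of relative error
        if diff > min_abs_error and (scale == 0 or diff / scale > max_rel_error):
            return False
    return True

x = np.array([0.5, -1.2, 2.0])
print(grad_check(lambda v: float(np.sum(v ** 2)), lambda v: 2 * v, x))  # True
```

Raising `minAbsoluteError` from 1e-7 to 1e-6 widens the band in which tiny absolute differences pass regardless of their relative error.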
@@ -56,11 +56,6 @@ import static org.junit.Assert.assertTrue;
  * @author Alex Black
  */
 public class YoloGradientCheckTests extends BaseDL4JTest {
-    private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
 
     static {
         Nd4j.setDataType(DataType.DOUBLE);
@@ -21,14 +21,13 @@ import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
-import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
-import org.deeplearning4j.nn.conf.layers.GravesLSTM;
-import org.deeplearning4j.nn.conf.layers.OutputLayer;
-import org.deeplearning4j.nn.conf.layers.PoolingType;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
 import org.nd4j.linalg.factory.Nd4j;
@@ -416,4 +415,53 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest {
             }
         }
     }
+
+    @Test
+    public void testMaskLayerDataTypes(){
+
+        for(DataType dt : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE,
+                DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64,
+                DataType.UINT8, DataType.UINT16, DataType.UINT32, DataType.UINT64}){
+            INDArray mask = Nd4j.rand(DataType.FLOAT, 2, 10).addi(0.3).castTo(dt);
+
+            for(DataType networkDtype : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE}){
+
+                INDArray in = Nd4j.rand(networkDtype, 2, 5, 10);
+                INDArray label1 = Nd4j.rand(networkDtype, 2, 5);
+                INDArray label2 = Nd4j.rand(networkDtype, 2, 5, 10);
+
+                for(PoolingType pt : PoolingType.values()) {
+                    //System.out.println("Net: " + networkDtype + ", mask: " + dt + ", pt=" + pt);
+
+                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                            .list()
+                            .layer(new GlobalPoolingLayer(pt))
+                            .layer(new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
+                            .build();
+
+                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                    net.init();
+
+                    net.output(in, false, mask, null);
+                    net.output(in, false, mask, null);
+
+                    MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
+                            .list()
+                            .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
+                            .build();
+
+                    MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
+                    net2.init();
+
+                    net2.output(in, false, mask, mask);
+                    net2.output(in, false, mask, mask);
+
+                    net.fit(in, label1, mask, null);
+                    net2.fit(in, label2, mask, mask);
+                }
+            }
+        }
+    }
 }
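The new `testMaskLayerDataTypes` builds its masks by sampling uniform values in [0.3, 1.3) and casting to each dtype, so the integer casts truncate to 0 or 1. A numpy analog of that construction (assumed semantics, for illustration only):

```python
import numpy as np

m = np.random.rand(2, 10) + 0.3   # uniform values in [0.3, 1.3)
print(m.astype(np.int8))          # truncation yields a random 0/1 mask
```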
@@ -19,8 +19,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-                %n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>
@@ -131,7 +131,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
         INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr);
 
         applyDropOutIfNecessary(training, workspaceMgr);
-        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, LayerWorkspaceMgr.noWorkspaces(), ArrayType.FF_WORKING_MEM);
+        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input.castTo(W.dataType()), workspaceMgr, ArrayType.FF_WORKING_MEM);
 
         INDArray act2d = layerConf().getActivationFn().getActivation(input2d.mmul(W).addiRowVector(b), training);
         if (maskArray != null) {
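The fix above casts the activations to the weight array's data type before the reshape and matmul, so both operands of `mmul` share one dtype. A rough numpy analog of the pattern (the ND4J API itself is not shown):

```python
import numpy as np

W = np.random.rand(3, 5).astype(np.float32)   # layer weights in float32
x = np.random.rand(8, 3).astype(np.float16)   # activations arriving in float16
act = x.astype(W.dtype) @ W                   # cast to the weight dtype, then matmul
print(act.dtype)                              # float32
```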
@@ -56,6 +56,7 @@ public class MaskedReductionUtil {
             throw new IllegalArgumentException("Expect rank 2 array for mask: got " + mask.rank());
         }
 
+        toReduce = toReduce.castTo(dataType);
         mask = mask.castTo(dataType);
 
         //Sum pooling: easy. Multiply by mask, then sum as normal
@@ -64,13 +65,7 @@ public class MaskedReductionUtil {
 
         switch (poolingType) {
             case MAX:
-                //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op
-                INDArray negInfMask;
-                if(mask.dataType() == DataType.BOOL){
-                    negInfMask = Transforms.not(mask).castTo(dataType);
-                } else {
-                    negInfMask = mask.rsub(1.0);
-                }
+                INDArray negInfMask = mask.castTo(dataType).rsub(1.0);
                 BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));
 
                 INDArray withInf = Nd4j.createUninitialized(dataType, toReduce.shape());
@@ -121,18 +116,14 @@ public class MaskedReductionUtil {
         //Mask: [minibatch, tsLength]
         //Epsilon: [minibatch, vectorSize]
 
+        mask = mask.castTo(input.dataType());
+
         switch (poolingType) {
             case MAX:
-                //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op
-                INDArray negInfMask;
-                if(mask.dataType() == DataType.BOOL){
-                    negInfMask = Transforms.not(mask).castTo(Nd4j.defaultFloatingPointType());
-                } else {
-                    negInfMask = mask.rsub(1.0);
-                }
+                INDArray negInfMask = mask.rsub(1.0);
                 BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));
 
-                INDArray withInf = Nd4j.createUninitialized(input.shape());
+                INDArray withInf = Nd4j.createUninitialized(input.dataType(), input.shape());
                 Nd4j.getExecutioner().exec(new BroadcastAddOp(input, negInfMask, withInf, 0, 2));
                 //At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op
@@ -145,7 +136,7 @@ public class MaskedReductionUtil {
             //if out = avg(in,dims) then dL/dIn = 1/N * dL/dOut
             //With masking: N differs for different time series
 
-                INDArray out = Nd4j.createUninitialized(input.shape(), 'f');
+                INDArray out = Nd4j.createUninitialized(input.dataType(), input.shape(), 'f');
 
                 //Broadcast copy op, then divide and mask to 0 as appropriate
                 Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, epsilon2d, out, 0, 1));
@@ -162,7 +153,7 @@ public class MaskedReductionUtil {
             case PNORM:
                 //Similar to average and sum pooling: there's no N term here, so we can just set the masked values to 0
-                INDArray masked2 = Nd4j.createUninitialized(input.shape());
+                INDArray masked2 = Nd4j.createUninitialized(input.dataType(), input.shape());
                 Nd4j.getExecutioner().exec(new BroadcastMulOp(input, mask, masked2, 0, 2));
 
                 INDArray abs = Transforms.abs(masked2, true);
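Both MAX branches above implement masked max-pooling by adding a mask that is 0 on kept timesteps and -inf on masked ones before taking the max. A numpy sketch of the trick (shapes follow the comments in the code; this is not the ND4J API):

```python
import numpy as np

to_reduce = np.random.rand(2, 3, 5)               # [minibatch, vectorSize, tsLength]
mask = np.array([[1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 1]], dtype=np.float64)

neg_inf_mask = 1.0 - mask                         # mask.rsub(1.0)
neg_inf_mask[neg_inf_mask == 1.0] = -np.inf       # replaceWhere(.., -inf, equals(1.0))

with_inf = to_reduce + neg_inf_mask[:, None, :]   # BroadcastAddOp over dims 0, 2
pooled = with_inf.max(axis=2)                     # masked steps can never win the max
print(pooled.shape)                               # (2, 3)
```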
@@ -21,8 +21,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-                %n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>
@@ -5,7 +5,7 @@ option(NATIVE "Optimize for build machine (might not work on others)" OFF)
 set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH})
 #ensure we create lib files
 set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS OFF)
-
+option(CHECK_VECTORIZATION "checks for vectorization" OFF)
 option(BUILD_TESTS "Build tests" OFF)
 option(FLATBUFFERS_BUILD_FLATC "Enable the build of the flatbuffers compiler" OFF)
 set(FLATBUFFERS_BUILD_FLATC "OFF" CACHE STRING "Hack to disable flatc build" FORCE)
@@ -5,7 +5,7 @@ project(mkldnn-download NONE)
 include(ExternalProject)
 ExternalProject_Add(mkldnn
     GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
-    GIT_TAG v1.1.2
+    GIT_TAG v1.1.3
     SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
     BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
     CONFIGURE_COMMAND ""
@@ -17,8 +17,11 @@ There's few additional arguments for `buildnativeoperations.sh` script you could
 -b release OR -b debug // enables/disables debug builds. release is considered by default
 -j XX // this argument defines how many threads will be used to build binaries on your box. i.e. -j 8
 -cc XX // CUDA-only argument, builds only binaries for target GPU architecture. use this for fast builds
+--check-vectorization auto-vectorization report for developers. (Currently, only GCC is supported)
 ```
 
+[More about AutoVectorization report](auto_vectorization/AutoVectorization.md)
+
 You can find the compute capability for your card [on the NVIDIA website here](https://developer.nvidia.com/cuda-gpus).
 
 For example, a GTX 1080 has compute capability 6.1, for which you would use ```-cc 61``` (note no decimal point).
@@ -0,0 +1,49 @@
+# Auto-vectorization Report
+
+This report tool produces a human-friendly view of the compiler output for the auto-vectorization process. It is intended to help developers investigate the obstacles the compiler faced during auto-vectorization.
+
+## Usage
+The ```--check-vectorization``` option should be added to the **release** build to get the auto-vectorization report:
+```./buildnativeoperations.sh -a native -j 28 --check-vectorization```
+It will output ```vecmiss.html``` inside the blasbuild/cpu folder.
+
+## Report Format
+Each filename contains info about optimization attempts for the source code lines.
+Each line number is also expandable (⇲) and contains distinct failure notes.
+It is possible to click on the line number to see the source code.
+
+| file name | total successful attempts | total failed attempts | ⇲ |
+|---|---|---|--|
+| line number | successful attempts | failed attempts | ⇲ |
+| - failure reasons |
+| line number | successful attempts | failed attempts | ⇲ |
+
+##### Requirements
+- GCC (Currently, only GCC is supported)
+- python3
+
+### Detailed report with the `-fsave-optimization-record` option
+If you want more detailed information (for now, it reports the functions where attempts failed) you should use a newer toolchain (GCC > 9), as newer GCC compilers have the `-fsave-optimization-record` option.
+`buildnativeoperations.sh`, via CMake, will detect it and switch to the more detailed version.
+Please note that this option is still experimental, so the compiler can fail to output some json.gz files with an error.
+In that case, try to exclude those files from the build.
+Also, the internal structure of the `-fsave-optimization-record` json.gz may change in the future.
+
+It outputs two files, **vecmiss_fsave.html** and **vecmiss_fsave.html.js**, so to see the report details you need to enable JavaScript in the browser if it was disabled.
+
+##### Requirements for the Detailed report
+- GCC version > 9
+- python3
+- Cython (python3)
+- json (python3)
+- gzip (python3)
+- c++filt
+
+Internally, we use Cython to speed up json.gz file processing (bigGzipJson.pyx), because json.gz files can take a lot of memory when loaded whole.
+
+If you want to use bigGzipJson outside of `buildnativeoperations.sh` and CMake, you should compile it manually with this command in the auto_vectorization folder:
+`python3 cython_setup.py build_ext --inplace`
+
+json.gz files can also be processed outside of `buildnativeoperations.sh`.
+You need to call `python3 auto_vect.py --fsave` inside the base source folder where the json.gz files exist.
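The plain (non-fsave) pipeline in the script that follows parses `file:line:column: message` lines from the compiler's stderr with the `mtch` regex. A tiny sketch, using a made-up message line:

```python
import re

mtch = re.compile(r"[^/]*([^:]+)\:(\d+)\:(\d+)\:(.*)")
sample = "../include/ops/impl/foo.cpp:70:18: note: loop vectorized"  # hypothetical line
m = mtch.match(sample)
print(m.group(1), m.group(2), m.group(4))  # file, line number, message
```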
@@ -0,0 +1,546 @@
+'''
+@author : Abdelrauf rauf@konduit.ai
+'''
+import re
+import sys
+import os
+import subprocess
+import fnmatch
+import json
+import gzip
+try:
+    from bigGzipJson import json_gzip_extract_objects
+except ImportError:
+    pass
+from pathlib import Path
+from multiprocessing import Pool, Manager, cpu_count
+import traceback
+import html
+
+mtch = re.compile(r"[^/]*([^:]+)\:(\d+)\:(\d+)\:(.*)")
+replace_msg = re.compile(r"(\d+)?\.?(\d+)?_?\d+\.?(\d+)?")
+progress_msg = re.compile(r"\s{0,4}\[\s{0,2}\d+\%\]")
+file_dir_strip = str(Path(os.getcwd()))
+pp_index = file_dir_strip.rfind("libnd4j")
+if pp_index>=0:
+    file_dir_strip = file_dir_strip[:pp_index+len("libnd4j")]
+BASE_URL = "https://github.com/eclipse/deeplearning4j/tree/master/libnd4j/"
+if BASE_URL.endswith("/")==False:
+    BASE_URL = BASE_URL + "/"
+#print(file_dir_strip)
+
+class info:
+    def __repr__(self):
+        return str(self.__dict__)
+
+FSAVE_IGNORE_EXTERNALS = True
+
+def get_cxx_filt_result(strx):
+    if len(strx)<1:
+        return ""
+    res = subprocess.Popen(["c++filt","-i", strx], stdout=subprocess.PIPE).communicate()[0]
+    res = res.decode('utf-8')
+    #replace some long names to reduce size
+    res = res.replace("unsigned long long", "uLL")
+    res = res.replace("unsigned long int", "uL")
+    res = res.replace("unsigned long", "uL")
+    res = res.replace("unsigned int", "ui")
+    res = res.replace("unsigned char", "uchar")
+    res = res.replace("unsigned short", "ushort")
+    res = res.replace("long long", "LL")
+    res = res.replace(", ", ",")
+    return res.strip()
+
+def internal_glob(dir, match):
+    listx = []
+    for root, dirnames, filenames in os.walk(dir):
+        for filename in fnmatch.filter(filenames, match):
+            listx.append(os.path.join(root, filename))
+    return listx
+
+def get_obj_json_gz(filename):
+    with gzip.GzipFile(filename, 'r') as f:
+        return json.loads(f.read().decode('utf-8'))[-1]
+
+def get_msg(msg):
+    msg = msg.lower().strip()
+    if "note: not vectorized:" in msg:
+        msg = replace_msg.sub("_numb", msg.replace("note: not vectorized:", ""))
+        return (0, 1, msg.strip())
+    elif "loop vectorized" in msg:
+        return (1, 0, None)
+    # elif msg.startswith("missed")==False:
+    #     msg = replace_msg.sub("_numb", msg)
+    #     return (0, 0, msg.strip())
+    return None
+
+class File_Info:
+    '''
+    Holds information about vectorized and miss vectorized lines for one file
+    '''
+
+    def __init__(self):
+        self.infos = {}
+        self.total_opted = 0
+        self.total_missed = 0
+        self.external = False
+
+    def add_line(self, line_pos):
+        if line_pos not in self.infos:
+            v = info()
+            v.optimized = 0
+            v.missed = 0
+            v.miss_details = set()
+            self.infos[line_pos] = v
+            return v
+        else:
+            return self.infos[line_pos]
+
+    def add_line_fsave(self, line_pos):
+        if line_pos not in self.infos:
+            v = info()
+            v.optimized = 0
+            v.missed = 0
+            v.miss_details2 = dict()
+            self.infos[line_pos] = v
+            return v
+        else:
+            return self.infos[line_pos]
+
+    def add_fsave(self, line_pos, success, msg, function, inline_fns=''):
+        v = self.add_line_fsave(line_pos)
+        if success and "loop vectorized" in msg:
+            v.optimized += 1
+            self.total_opted += 1
+        elif success==False and "not vectorized:" in msg:
+            #reduce this msg
+            msg = msg.replace("not vectorized:", "")
+            v.missed += 1
+            self.total_missed += 1
+            msg = sys.intern(msg)
+            if msg in v.miss_details2:
+                ls = v.miss_details2.get(msg)
+                ls.add(function)
+            else:
+                ls = set()
+                v.miss_details2[msg] = ls
+                ls.add(function)
+        return self
+
+    def add(self, line_pos, msg_x):
+        v = self.add_line(line_pos)
+        if msg_x is not None:
+            v.optimized += msg_x[0]
+            v.missed += msg_x[1]
+            self.total_opted += msg_x[0]
+            self.total_missed += msg_x[1]
+            if msg_x[2] is not None:
+                v.miss_details.add(msg_x[2])
+        return self
+
+    def __repr__(self):
+        return str(self.__dict__)
+
+def process_gzip_json_mp(args):
+    process_gzip_json_new(*args)
+
+def process_gzip_json_new(json_gz_fname, list_Queue):
+    gz_name = Path(json_gz_fname).stem
+    #print("::--open and process {0}".format(gz_name))
+    queue_count = len(list_Queue)
+    #print(queue_count)
+    q = list_Queue[0]
+    old_fname = ''
+    total_c = 0
+    for x in json_gzip_extract_objects(json_gz_fname, 'message', 'vectorized'):
+        external_source = True
+        if len(x['message'])>0 and 'location' in x:
+            line = int(x['location']['line'])
+            file_name = x['location']['file'].strip()
+            if file_dir_strip in file_name:
+                file_name = file_name.replace(file_dir_strip, './')
+                external_source = False
+            msg = x['message'][0]
+            success = x['kind'] == 'success'
+            func = '' if 'function' not in x else x['function']
+
+            if file_name != old_fname:
+                #send our info to the right consumer
+                queue_ind = hash(file_name) % queue_count
+                #print("queue index {0}".format(queue_ind))
+                q = list_Queue[queue_ind]
+                old_fname = file_name
+            total_c += 1
+            #print("pp {0} {1}".format(q, (file_name, line, success, msg, func, external_source)))
+            if FSAVE_IGNORE_EXTERNALS==True and external_source == True:
+                continue
+            q.put((file_name, line, success, msg, func, external_source))
+    print("::finished {0:60s} :{1:8d}".format(gz_name, total_c))
+
+def consume_processed_mp(args):
+    return consume_processed_new(*args)
+
+def consume_processed_new(list_Queue, c_index):
+
+    info_ = dict()
+    func_list = dict()
+    last_func_index = 0
+    q = list_Queue[c_index]
+    print("::consumer {0}".format(c_index))
+    total_c = 0
+    r_c = 0
+    while True:
+        #print("try to get new from {0}".format(index))
+        obj = q.get()
+        #print("cc {0} {1}".format(q, obj))
+        if obj==None:
+            break #we received the end
+        file_name, line, success, msg, func, external_source = obj
+        try:
+            #get function index
+            func_index = -1
+            if func in func_list:
+                func_index = func_list[func]
+            else:
+                func_list[func] = last_func_index
+                func_index = last_func_index
+                last_func_index += 1
+
+            if file_name in info_:
+                info_[file_name].add_fsave(line, success, msg, func_index)
+            else:
+                info_[file_name] = File_Info().add_fsave(line, success, msg, func_index)
+                info_[file_name].external = external_source
+            total_c += 1
+            if total_c - r_c > 10000:
+                r_c = total_c
+                print("::consumer {0:2d} :{1:10d}".format(c_index, total_c))
+        except Exception as e:
+            print(traceback.format_exc())
+            break
+
+    print("::consumer {0:2d} :{1:10d}".format(c_index, total_c))
+    #write to temp file
+    wr_fname = "vecmiss_fsave{0}.html".format(str(c_index) if len(list_Queue)>1 else '')
+    print("generate report for consumer {0} {1}".format(c_index, len(info_)))
+    try:
+        uniq_ind = str(c_index)+'_' if len(list_Queue)>1 else ''
+        generate_report(wr_fname, info_, only_body = False, unique_id_prefix = uniq_ind, fsave_format = True, function_list = func_list)
+        print(" consumer {0} saved output into {1}".format(c_index, wr_fname))
+    except Exception as e:
+        print(traceback.format_exc())
+
+def obtain_info_from(input_):
+    info_ = dict()
+    for line in input_:
+        x = mtch.match(line)
+        external_source = True
+        if x:
+            file_name = x.group(1).strip()
+            if file_dir_strip in file_name:
+                file_name = file_name.replace(file_dir_strip, '')
+                external_source = False
+            line_number = int(x.group(2))
+            msg = x.group(4).lower()
+            msg = msg.replace(file_dir_strip, './')
+            msg_x = get_msg(msg)
+            if msg_x is None:
+                continue
+            if file_name in info_:
+                #ignore col_number
+                info_[file_name].add(line_number, msg_x)
+            else:
+                #print("{0} {1}".format(file_name, external_source))
+                info_[file_name] = File_Info().add(line_number, msg_x)
+                info_[file_name].external = external_source
+        elif progress_msg.match(line):
+            #actually we redirect only stderr, so this should not happen
+            print("__"+line.strip())
+        elif "error" in line or "Error" in line:
+            print("****"+line.strip())
+    return info_
+
+def custom_style(fsave):
+    st = '''<style>a{color:blue;}
+a:link{text-decoration:none}a:visited{text-decoration:none}a:hover{cursor:pointer;text-decoration:underline}
+a:active{text-decoration:underline}
+.f.ext{display:none}
+.f{color:#000;display:flex;overflow:hidden;justify-content:space-between;flex-wrap:wrap;align-items:baseline;width:100%}
+.f>div{min-width:10%}.f>div:first-child{min-width:70%;text-overflow:ellipsis}
+.f:nth-of-type(even){background-color:#f5f5f5}
+.f>div.g{flex:0 0 100%}.f>div:nth-child(2){font-weight:600;color:green}
+.f>div:nth-child(3){font-weight:600;color:red}
+.f>div:nth-child(2)::after{content:' ✓';color:green}.f>div:nth-child(3)::after{content:' -';color:red}
+.f>div.g>div>div:nth-child(2){font-weight:600;color:green}
+.f>div.g>div>div:nth-child(3){font-weight:600;color:red}
+.f>div.g>div>div:nth-child(2)::after{content:' ✓';color:green}
+.f>div.g>div>div:nth-child(3)::after{content:' -';color:red}
+.f>div.g>div{display:flex;justify-content:space-between;flex-wrap:wrap;align-items:baseline}
+.f>div.g>div>div{min-width:10%;text-align:left}
+.g>div:nth-of-type(even){background-color:#ede6fa}
+.f>div.g>div>ul{flex:0 0 100%}input[type=checkbox]{opacity:0;display:none}label{cursor:pointer}
+.f>label{color:red}input[type=checkbox]~.g{display:none}input[type=checkbox]:checked~.g{display:block}
+input[type=checkbox]~ul{display:none}
+input[type=checkbox]:checked~ul{display:block}input[type=checkbox]+label::after{content:"⇲";display:block}
+input[type=checkbox]:checked+label::after{content:"⇱";display:block}
+'''
+    if fsave==True:
+        st += '''.modal{display:none;height:100%;background-color:#144F84;color:#fff;opacity:.93;left:0;position:fixed;top:0;width:100%}
+.modal.open{display:flex;flex-direction:column}.modal__header{height:auto;font-size:large;padding:10px;background-color:#000;color:#fff}
+.modal__footer{height:auto;font-size:medium;background-color:#000}
+.modal__content{height:100%;display:flex;flex-direction:column;padding:20px;overflow-y:auto}
+.modal_close{cursor:pointer;float:right}li{cursor:pointer}
+'''
+    return st + '''</style>'''
+
+def header(fsave=False):
+    strx = '<!DOCTYPE html>\n<html>\n<head>\n<meta charset="UTF-8">\n<title>Auto-Vectorization</title>\n'
+    strx += '<base id="base_id" href="{0}" target="_blank" >'.format(BASE_URL)
+    strx += custom_style(fsave)
+    strx += '\n</head>\n<body>\n'
+    return strx
+
+def footer():
+    return '\n</body></html>'
+
+def get_compressed_indices(set_a):
+    a_len = len(set_a)
+    if a_len<=1:
+        if a_len<1:
+            return ''
+        return str(set_a)[1:-1]
+    #we sorted and only saved the differences
+    # 1,14,15,19 --> 1,13,1,4 10bytes=>8bytes
+    list_sorted = sorted(list(set_a))
+    last = list_sorted[0]
+    str_x = str(list_sorted[0])
+    for i in range(1, a_len):
+        str_x += ','+str(list_sorted[i]-last)
+        last = list_sorted[i]
+    return str_x
+
+def get_content(k, v, unique_id_prefix = '', fsave_format=False):
+    inner_str = ''
+    content = ''
+    inc_id = 0
+    for fk, fv in sorted(v.infos.items()):
+        if fsave_format==True:
+            inner_str += '<div><div><a>{0}</a></div><div>{1}</div><div>{2}</div><input type="checkbox" id="c{3}{4}"><label for="c{3}{4}"></label><ul>'.format(
+                fk, fv.optimized, fv.missed, unique_id_prefix, inc_id)
+        else:
+            inner_str += '<div><div><a href=".{0}#L{1}">{1}</a></div><div>{2}</div><div>{3}</div><input type="checkbox" id="c{4}{5}"><label for="c{4}{5}"></label><ul>'.format(
+                k, fk, fv.optimized, fv.missed, unique_id_prefix, inc_id)
+        inc_id += 1
+        if fsave_format==True:
+            for dt, df in fv.miss_details2.items():
+                #inner_str += '<li data-fns="{0}">{1}</li>'.format(str(df).replace(", ", ",")[1:-1], dt)
+                inner_str += '<li data-fns="{0}">{1}</li>'.format(get_compressed_indices(df), dt)
+        else:
+            for dt in fv.miss_details:
+                inner_str += "<li>"+str(dt)+"</li>"
+        inner_str += "</ul></div>\n"
+
+    content += '<div class="f'
+    if v.external:
+        content += " ext"
+    content += '">\n<div>{0}</div><div>{1}</div><div>{2}</div><input type="checkbox" id="i{3}{4}"><label for="i{3}{4}"></label>'.format(
+        k, v.total_opted, v.total_missed, unique_id_prefix, inc_id)
+    content += "<div class='g'>"
+    content += inner_str
+    content += "</div> </div>\n"
+    return content
+
+def jscript_head():
+    return '''
+window.onload = function () {
+    var modal = document.getElementsByClassName("modal")[0];
+    var modal_close = document.getElementsByClassName("modal_close")[0];
+    var content = document.getElementsByClassName("modal__content")[0];
+    a_tags = document.getElementsByTagName("a");
+    base_href = document.getElementById("base_id").href;
+    for(i=0;i<a_tags.length;i++){
+        a_tags[i].addEventListener("click", function () {
+            var source = event.target || event.srcElement;
+            file_src = source.parentElement.parentElement.parentElement.parentElement.children[0].innerText;
+            link = base_href + file_src+'#L'+ source.innerText;
+            window.open(link, '_blank');
+        });
+    }
+    modal_close.addEventListener("click", function () {
+        content.innerHTML = '';
+        modal.className = 'modal';
+    });
+'''
+
+def jscipt_end():
+    return '''
+    tags = document.getElementsByTagName("li");
+    function escapeHtml(unsafe) {
+        return unsafe
+            .replace(/&/g, "&amp;")
+            .replace(/</g, "&lt;")
+            .replace(/>/g, "&gt;")
+            .replace(/"/g, "&quot;")
+            .replace(/'/g, "&#039;");
+    }
+    for (i = 0; i < tags.length; i++) {
+        tags[i].addEventListener("click", function () {
+            var source = event.target || event.srcElement;
+            funcs = source.dataset.fns.split(",")
+            strx = ''
+            //we saved differences, not real indices
+            last_ind = 0;
+            for (j = 0; j < funcs.length; j++) {
+                ind = last_ind + parseInt(funcs[j]);
+                strx += "<p>" + escapeHtml(func_list[ind]) + "</p>";
+                last_ind = ind;
+            }
+            if (strx.length > 0) {
+                content.innerHTML = strx;
+                modal.className = 'modal open';
+            }
+        });
+    }
+};'''
+
+def additional_tags(fsave):
+    if fsave==False:
+        return ''
+    return '''<script type='text/javascript'>
+var script = document.createElement('script'); script.src = window.location.href+".js";
+document.head.appendChild(script);
+</script>
+<div class="modal">
+<div class="modal__header">Functions <span class="modal_close">X</span></div>
+<div class="modal__content"></div>
+<div class="modal__footer">========</div>
+</div>
+'''
+
+def generate_report(output_name, info_, only_body = False, unique_id_prefix='', fsave_format = False, function_list = None):
+    '''
+    Generate Auto-Vectorization Report in html format
+    '''
+    temp_str = ''
+    if fsave_format==True:
+        #we dump function_list as a key list sorted by value
+        #and use it as a jscript array
+        sorted_funcs_by_index = sorted(function_list.items(), key=lambda x: x[1])
+        del function_list
+        with open(output_name+".js", "w") as f:
+            #temp_str = jscript_head() + '{ "fmaps":['
+            temp_str = jscript_head() + "\n var func_list = ["
+            for k, v in sorted_funcs_by_index:
+                #json.dumps is used for escaping
+                #print(str(v)+str(k))
+                temp_str += json.dumps(get_cxx_filt_result(k))+","
+                #reduce write calls
+                if len(temp_str)>8192*2:
+                    f.write(temp_str)
+                    temp_str = ''
+            if len(temp_str)>0:
+                f.write(temp_str)
+            f.write('"-"];'+jscipt_end())
+
+    temp_str = ''
+    with open(output_name, "w") as f:
+        if only_body==False:
+            f.write(header(fsave_format))
+            f.write(additional_tags(fsave_format))
+        nm = 0
+        for k, v in sorted(info_.items()): # sorted(info_.items(), key=lambda x: x[1].total_opted, reverse=True):
+            temp_str += get_content(k, v, unique_id_prefix+str(nm), fsave_format)
+            #reduce io write calls
+            if len(temp_str)>8192:
+                f.write(temp_str)
+                temp_str = ''
+            nm += 1
+        if len(temp_str)>0:
+            f.write(temp_str)
+        if only_body==False:
+            f.write(footer())
+
+def fsave_report_launch(json_gz_list):
+
+    cpus = cpu_count()
+    if cpus>32:
+        cpus = 24
+
+    c_count = 1 # 2 is sufficient # if cpus<=1 else min(4, cpus)
+    p_count = 3 if cpus<=1 else max(8, cpus - c_count)
+
+    m = Manager()
+    #consumer Queues
+    list_Queue = [m.Queue() for index in range(0, c_count)]
+    with Pool(processes=c_count) as consumers:
+        #start consumers
+        cs = consumers.map_async(consume_processed_mp, [(list_Queue, index,) for index in range(0, c_count)])
+        with Pool(processes=p_count) as processors:
+            processors.map(process_gzip_json_mp, [(fname, list_Queue,) for fname in json_gz_list])
+
+        #send ends to inform our consumers
+        for q in list_Queue:
+            q.put(None)
+
+        #wait for consumers
+        cs.wait()
+
+def main():
+    if "--fsave" in sys.argv:
+        json_gz_list = internal_glob(".", "*.json.gz")
+        fsave_report_launch(json_gz_list)
+        return
+
+    file_info = obtain_info_from(sys.stdin)
+    if len(file_info)>0:
+        #print(file_info)
+        print("---generating vectorization html report--")
+        generate_report("vecmiss.html", file_info)
+    else:
+        #let's check if we got fsave files
+        json_gz_list = internal_glob(".", "*.json.gz")
+        fsave_report_launch(json_gz_list)
+
+if __name__ == '__main__':
+    main()
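One detail worth noting in the script above: `get_compressed_indices` serializes a sorted index set as `first,delta,delta,...`, and the generated JavaScript (`jscipt_end`) rebuilds the indices with a running sum. A self-contained round-trip sketch (helper names are hypothetical):

```python
def compress(indices):
    xs = sorted(indices)
    return ','.join([str(xs[0])] + [str(b - a) for a, b in zip(xs, xs[1:])])

def decompress(s):
    out, last = [], 0
    for tok in s.split(','):
        last += int(tok)
        out.append(last)
    return out

print(compress({1, 14, 15, 19}))   # "1,13,1,4"  (as in the code comment)
print(decompress("1,13,1,4"))      # [1, 14, 15, 19]
```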
@ -0,0 +1,354 @@
|
||||||
|
'''
|
||||||
|
@author : Abdelrauf rauf@konduit.ai
|
||||||
|
Simple object xtractor form very big json files
|
||||||
|
'''
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
|
||||||
|
|
||||||
|
|
||||||
|
cdef char JSON_1 = b':'
|
||||||
|
cdef char JSON_2 = b','
|
||||||
|
cdef char JSON_3 = b'{'
|
||||||
|
cdef char JSON_4 = b'}'
|
||||||
|
cdef char JSON_5 = b'['
|
||||||
|
cdef char JSON_6 = b']'
|
||||||
|
cdef char QUOTE = b'"'
|
||||||
|
cdef char ESCAPE = b"\\"
|
||||||
|
cdef char SPACE = b' '
|
||||||
|
cdef char TAB = b't'
|
||||||
|
cdef char CR = b'\r'
|
||||||
|
cdef char NL = b'\n'
|
||||||
|
cdef char B = b'\b'
|
||||||
|
cdef char EMPTY = b'\0'
|
||||||
|
|
||||||
|
|
||||||
|
cdef struct Span:
|
||||||
|
int b
|
||||||
|
int e
|
||||||
|
|
||||||
|
cdef inline Span read_unquoted(char *text, int start,int end):
|
||||||
|
cdef Span sp
|
||||||
|
cdef int j = start
|
||||||
|
while j < end:
|
||||||
|
#if text[j].isspace():
|
||||||
|
if text[j] == SPACE or text[j] == NL or text[j] == TAB or text[j] == CR or text[j] == B:
|
||||||
|
j += 1
|
||||||
|
continue
|
||||||
|
if text[j] != QUOTE and text[j] != JSON_1 and text[j] != JSON_2 and text[j] != JSON_3 and text[j] != JSON_4 and text[j] != JSON_5 and text[j] != JSON_6:
|
||||||
|
start = j
|
||||||
|
j += 1
|
||||||
|
while j < end:
|
||||||
|
# read till JSON or white space
|
||||||
|
if text[j] == SPACE or text[j] == NL or text[j] == TAB or text[j] == CR or text[j] == B:
|
||||||
|
sp.b = start
|
||||||
|
sp.e = j
|
||||||
|
return sp
|
||||||
|
elif text[j] == JSON_1 or text[j] == JSON_2 or text[j] == JSON_3 or text[j] == JSON_4 or text[j] == JSON_5 or text[j] == JSON_6:
|
||||||
|
sp.b = start
|
||||||
|
sp.e = j
|
||||||
|
return sp
|
||||||
|
j += 1
|
||||||
|
if j == end-1:
|
||||||
|
sp.b = start
|
||||||
|
sp.e = end
|
||||||
|
return sp
|
||||||
|
break
|
||||||
|
sp.b = j
|
||||||
|
sp.e = j
|
||||||
|
return sp
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline Span read_seq_token(char *text,int start,int end):
|
||||||
|
#read quoted
|
||||||
|
#skip white_space
|
||||||
|
cdef Span sp
|
||||||
|
cdef int j = start
|
||||||
|
cdef char last_char
|
||||||
|
cdef char char_x
|
||||||
|
while j < end:
|
||||||
|
if text[j] == SPACE or text[j] == NL or text[j] == TAB or text[j] == CR or text[j] == B:
|
||||||
|
j += 1
|
||||||
|
continue
|
||||||
|
if text[j] == QUOTE:
|
||||||
|
last_char = EMPTY
|
||||||
|
#read till another quote
|
||||||
|
start = j
|
||||||
|
j += 1
|
||||||
|
while j < end:
|
||||||
|
char_x = text[j]
|
||||||
|
if char_x == QUOTE and last_char != ESCAPE:
|
||||||
|
# finished reading
|
||||||
|
sp.b =start
|
||||||
|
sp.e = j+1
|
||||||
|
return sp
|
||||||
|
last_char = char_x
|
||||||
|
j += 1
|
||||||
|
if j == end-1:
|
||||||
|
sp.b = start
|
||||||
|
sp.e = end
|
||||||
|
return sp
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return read_unquoted(text, j, end)
|
||||||
|
|
||||||
|
|
||||||
|
def tokenizer_spans(utext):
|
||||||
|
'''
|
||||||
|
we will just return tokenize spans
|
||||||
|
'''
|
||||||
|
token_spans = []
|
||||||
|
last_char = b''
|
||||||
|
end_i = len(utext)
|
||||||
|
cdef char *text = utext
|
||||||
|
i = 0
|
||||||
|
cdef Span sp
|
||||||
|
while i < end_i:
|
||||||
|
sp = read_seq_token(text, i, end_i)
|
||||||
|
i = sp.e
|
||||||
|
if sp.e > sp.b:
|
||||||
|
token_spans.append((sp.b, sp.e))
|
||||||
|
if i < end_i:
|
||||||
|
#if text[i] in JSON:
|
||||||
|
if text[i] == JSON_3 or text[i] == JSON_4 or text[i] == JSON_5 or text[i] == JSON_6 or text[i] == JSON_1 or text[i] == JSON_2:
|
||||||
|
token_spans.append((i, i+1))
|
||||||
|
i += 1
|
||||||
|
return token_spans
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
cdef class JsonObjXtractor:
|
||||||
|
'''
|
||||||
|
JsonObjXtractor that utilize cython better
|
||||||
|
'''
|
||||||
|
|
||||||
|
cdef Span* token_spans
|
||||||
|
cdef size_t size
|
||||||
|
|
||||||
|
def __cinit__(self, size_t count=4096):
|
||||||
|
self.token_spans = <Span*> PyMem_Malloc(count * sizeof(Span))
|
||||||
|
self.size = count
|
||||||
|
if not self.token_spans:
|
||||||
|
raise MemoryError()
|
||||||
|
|
||||||
|
|
||||||
|
def __tokenizer_spans(self,utext, length):
|
||||||
|
'''
|
||||||
|
we will just return token spans length
|
||||||
|
'''
|
||||||
|
|
||||||
|
last_char = b''
|
||||||
|
end_i = length
|
||||||
|
cdef char *text = utext
|
||||||
|
cdef int i = 0
|
||||||
|
cdef size_t j = 0
|
||||||
|
cdef Span sp
|
||||||
|
while i < end_i:
|
||||||
|
sp = read_seq_token(text, i, end_i)
|
||||||
|
i = sp.e
|
||||||
|
if sp.e > sp.b:
|
||||||
|
self.token_spans[j] = sp
|
||||||
|
j+=1
|
||||||
|
if j>self.size:
|
||||||
|
#we need to reallocate
|
||||||
|
self.__resize(self.size+self.size//2)
|
||||||
|
if i < end_i:
|
||||||
|
#if text[i] in JSON:
|
||||||
|
if text[i] == JSON_3 or text[i] == JSON_4 or text[i] == JSON_5 or text[i] == JSON_6 or text[i] == JSON_1 or text[i] == JSON_2:
|
||||||
|
sp.b=i
|
||||||
|
sp.e=i+1
|
||||||
|
self.token_spans[j] = sp
|
||||||
|
j+=1
|
||||||
|
if j>self.size:
|
||||||
|
#we need to reallocate
|
||||||
|
self.__resize(self.size+self.size//2)
|
||||||
|
i += 1
|
||||||
|
return j
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def try_extract_parent_obj(self, json_bytes, property_name, next_contains_value=b'', debug=False):
|
||||||
|
'''
|
||||||
|
try_extract_parent_obj(json_text, property_name, next_contains_value='', debug=False):
|
||||||
|
make sure that passed variables encoded to bytes with encode('utf-8')
|
||||||
|
next_contains_value either direct content or followed by '['
|
||||||
|
tries to extract the parent object for given named object
|
||||||
|
if the left brace of the parent object is outside of the current buffer
|
||||||
|
it will be ignored
|
||||||
|
if the right brace is outside of the buffer it will be left to be handled by caller
|
||||||
|
'''
|
||||||
|
|
||||||
|
look_for_the_left = True
|
||||||
|
parent_left = []
|
||||||
|
parent_right = []
|
||||||
|
parent_objects = []
|
||||||
|
len_next = len(next_contains_value)
|
||||||
|
cdef int ind = 0
|
||||||
|
cdef int end
|
||||||
|
cdef int last_start = 0
|
||||||
|
property_name = b'"'+property_name+b'"'
|
||||||
|
cdef int lenx = self.__tokenizer_spans(json_bytes,len(json_bytes))
|
||||||
|
cdef char x
|
||||||
|
cdef int i = -1
|
||||||
|
cdef Span sp
|
||||||
|
while i < lenx-1:
|
||||||
|
i += 1
|
||||||
|
ind = self.token_spans[i].b
|
||||||
|
x = json_bytes[ind]
|
||||||
|
#print("-----{0} -- {1} -- {2} ".format(x,parent_left,parent_right))
|
||||||
|
if look_for_the_left == False:
|
||||||
|
if x == JSON_3:
|
||||||
|
parent_right.append(ind)
|
||||||
|
elif x == JSON_4:
|
||||||
|
if len(parent_right) == 0:
|
||||||
|
#we found parent closing brace
|
||||||
|
look_for_the_left = True
|
||||||
|
parent_objects.append((parent_left[-1], ind+1))
|
||||||
|
last_start = ind+1
|
||||||
|
#print("=============found {0}".format(parent_objects))
|
||||||
|
parent_left = []
|
||||||
|
parent_right = []
|
||||||
|
else:
|
||||||
|
parent_right.pop()
|
||||||
|
continue
|
||||||
|
#search obj
|
||||||
|
if look_for_the_left:
|
||||||
|
if x == JSON_3:
|
||||||
|
parent_left.append(ind)
|
||||||
|
last_start = ind
|
||||||
|
elif x == JSON_4:
|
||||||
|
if len(parent_left) >= 1:
|
||||||
|
#ignore
|
||||||
|
parent_left.pop()
|
||||||
|
|
||||||
|
if x == JSON_1: # ':'
|
||||||
|
#check to see if propertyname
|
||||||
|
old_property = EMPTY
|
||||||
|
if i > 1:
|
||||||
|
sp = self.token_spans[i-1]
|
||||||
|
old_property = json_bytes[sp.b:sp.e]
|
||||||
|
if old_property == property_name:
|
||||||
|
#we found
|
||||||
|
if len(parent_left) < 1:
|
||||||
|
#left brace is outside of the buffer
|
||||||
|
#we have to ignore it
|
||||||
|
#try to increase buffer
|
||||||
|
if debug:
|
||||||
|
print('''left brace of the parent is outside of the buffer and the parent is big.
|
||||||
|
it will be ignored
|
||||||
|
try to choose unambiguous property names if you are looking for small objects''', file=sys.stderr)
|
||||||
|
last_start = ind+1
|
||||||
|
parent_left = []
|
||||||
|
parent_right = []
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
#print("++++++ look for the right brace")
|
||||||
|
if len_next>0 and i+1 < lenx:
|
||||||
|
i += 1
|
||||||
|
ind = self.token_spans[i].b
|
||||||
|
end = self.token_spans[i].e
|
||||||
|
m = json_bytes[ind]
|
||||||
|
|
||||||
|
if m == JSON_5:
|
||||||
|
#print ("----{0} {1}".format(m,JSON_5))
|
||||||
|
if i+1 < lenx:
|
||||||
|
i += 1
|
||||||
|
ind = self.token_spans[i].b
|
||||||
|
end = self.token_spans[i].e
|
||||||
|
#print ("----{0} == {1}".format(next_contains_value,json_bytes[ind:end]))
|
||||||
|
if len_next <= end-ind and next_contains_value in json_bytes[ind:end]:
|
||||||
|
look_for_the_left = False
|
||||||
|
continue
|
||||||
|
elif len_next <= end-ind and next_contains_value in json_bytes[ind:end]:
|
||||||
|
look_for_the_left = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
#ignore as it does not have that value
|
||||||
|
parent_left = []
|
||||||
|
parent_right = []
|
||||||
|
last_start = ind + 1
|
||||||
|
else:
|
||||||
|
look_for_the_left = False
|
||||||
|
|
||||||
|
# return the position of the last successfully opened brace,
|
||||||
|
# or, in the left-brace failure case, the last safely closed position
|
||||||
|
if len(parent_left)>0:
|
||||||
|
return (parent_objects, parent_left[-1])
|
||||||
|
|
||||||
|
return (parent_objects, last_start)
|
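# A minimal usage sketch (hedged: the file name and values are hypothetical).
# All arguments must be bytes; the property name is passed without quotes,
# since the method wraps it in '"' itself:
#
#   objXt = JsonObjXtractor()
#   data = open('data.json', 'rb').read()
#   objects, last_index = objXt.try_extract_parent_obj(data, b'id', b'XYZ')
#   for start, end in objects:
#       print(json.loads(data[start:end]))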
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def __resize(self, size_t new_count):
|
||||||
|
cdef Span* mem = <Span*> PyMem_Realloc(self.token_spans, new_count * sizeof(Span))
|
||||||
|
if not mem:
|
||||||
|
raise MemoryError()
|
||||||
|
self.token_spans = mem
|
||||||
|
self.size = new_count
|
||||||
|
|
||||||
|
def __dealloc__(self):
|
||||||
|
PyMem_Free(self.token_spans)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import gzip
|
||||||
|
import sys
|
||||||
|
DEBUG_LOG = False
|
||||||
|
|
||||||
|
def json_gzip_extract_objects(filename, property_name, next_contains_value=''):
|
||||||
|
strx = b''
|
||||||
|
started = False
|
||||||
|
b_next_contains_value = next_contains_value.encode('utf-8')
|
||||||
|
b_property_name = property_name.encode('utf-8')
|
||||||
|
#print(b_property_name)
|
||||||
|
objXt = JsonObjXtractor()
|
||||||
|
with gzip.open(filename, 'rb') as f:
|
||||||
|
if DEBUG_LOG:
|
||||||
|
print("opened {0}".format(filename), file=sys.stderr)
|
||||||
|
#instead of reading it as line, I will read it as binary bytes
|
||||||
|
is_End = False
|
||||||
|
#total = 0
|
||||||
|
while not is_End:
|
||||||
|
buffer = f.read(8192*2)
|
||||||
|
|
||||||
|
lenx = len(buffer)
|
||||||
|
#total +=lenx
|
||||||
|
if lenx < 1:
|
||||||
|
is_End = True
|
||||||
|
else:
|
||||||
|
strx = strx + buffer
|
||||||
|
|
||||||
|
objects, last_index = objXt.try_extract_parent_obj(strx, b_property_name, b_next_contains_value)
|
||||||
|
|
||||||
|
# if b_property_name in strx and b_next_contains_value in strx:
|
||||||
|
# print(strx)
|
||||||
|
# print(objects)
|
||||||
|
# print(last_index)
|
||||||
|
# print("===================================================")
|
||||||
|
|
||||||
|
for start,end in objects:
|
||||||
|
yield json.loads(strx[start:end])
|
||||||
|
|
||||||
|
|
||||||
|
#remove processed
|
||||||
|
if last_index < len(strx):
|
||||||
|
strx = strx[last_index:]
|
||||||
|
|
||||||
|
else:
|
||||||
|
strx = b''
|
||||||
|
#print('----+++')
|
||||||
|
|
||||||
|
if len(strx) > 16384*3:
|
||||||
|
#buffer too big
|
||||||
|
#try to avoid big parents
|
||||||
|
if DEBUG_LOG:
|
||||||
|
print("parent object is too big. please, look for better property name", file=sys.stderr)
|
||||||
|
|
||||||
|
break
|
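# A minimal usage sketch (the archive name, property name and value are hypothetical):
#
#   for obj in json_gzip_extract_objects('dump.json.gz', 'species', 'Homo sapiens'):
#       print(obj)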
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
from distutils.core import setup
|
||||||
|
from Cython.Build import cythonize
|
||||||
|
setup(ext_modules=cythonize("bigGzipJson.pyx", language_level="3"))
|
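# To build the extension in place (assuming Cython is installed):
#   python3 cython_setup.py build_ext --inplace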
|
@ -282,6 +282,32 @@ elseif(CPU_BLAS)
|
||||||
set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(CHECK_VECTORIZATION)
|
||||||
|
set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
|
||||||
|
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||||
|
|
||||||
|
if (CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||||
|
set(CHECK_VECT_FLAGS "-ftree-vectorize -fsave-optimization-record")
|
||||||
|
#to process the -fsave-optimization-record output we need our Cython helper code
|
||||||
|
message("Build Auto vectorization helpers")
|
||||||
|
execute_process(COMMAND "python3" "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/cython_setup.py" "build_ext" "--inplace" WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/" RESULT_VARIABLE ret)
|
||||||
|
message("build='${ret}'")
|
||||||
|
|
||||||
|
#remove failure cases that gcc sometimes fails to process
|
||||||
|
file(GLOB_RECURSE FAILURE_CASES false ../include/loops/cpu/compilation_units/reduce3*.cpp)
|
||||||
|
#message("*****${FAILURE_CASES}")
|
||||||
|
foreach(FL_ITEM ${FAILURE_CASES})
|
||||||
|
message("Removing failure cases ${FL_ITEM}")
|
||||||
|
list(REMOVE_ITEM VECT_FILES ${FL_ITEM})
|
||||||
|
endforeach()
|
||||||
|
else()
|
||||||
|
set(CHECK_VECT_FLAGS "-ftree-vectorize -fopt-info-vec-optimized-missed")
|
||||||
|
endif()
|
||||||
|
message("CHECK VECTORIZATION ${CHECK_VECT_FLAGS}")
|
||||||
|
set_source_files_properties( ${VECT_FILES} PROPERTIES COMPILE_FLAGS "${CHECK_VECT_FLAGS}" )
|
||||||
|
endif()
|
||||||
|
endif()
|
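# Note: CHECK_VECTORIZATION is passed in from the build script via
# -DCHECK_VECTORIZATION (see the --check-vectorization flag added below).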
||||||
|
|
||||||
message("CPU BLAS")
|
message("CPU BLAS")
|
||||||
add_definitions(-D__CPUBLAS__=true)
|
add_definitions(-D__CPUBLAS__=true)
|
||||||
add_library(nd4jobj OBJECT cpu/NativeOps.cpp cpu/GraphExecutioner.cpp
|
add_library(nd4jobj OBJECT cpu/NativeOps.cpp cpu/GraphExecutioner.cpp
|
||||||
|
|
|
@ -195,6 +195,56 @@ namespace nd4j {
|
||||||
|
|
||||||
NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* These constructors create a scalar array containing a UTF-8 string
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
NDArray(const char* str, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
|
||||||
|
: NDArray(std::string(str), dtype, context) {
|
||||||
|
}
|
||||||
|
NDArray(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* These constructors create a scalar array containing a UTF-16 string
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
NDArray(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
|
||||||
|
: NDArray(std::u16string(u16string), dtype, context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
NDArray(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* These constructors create a scalar array containing a UTF-32 string
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
NDArray(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
|
||||||
|
: NDArray(std::u32string(u32string), dtype, context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
NDArray(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* These constructors create an array from a vector of UTF-8 strings
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::string>& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* These constructors create an array from a vector of UTF-16 strings
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* These constructors create an array from a vector of UTF-32 strings
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
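// A usage sketch for the new string constructors (values are hypothetical):
//   NDArray scalar(u"hello", nd4j::DataType::UTF16);
//   NDArray arr({2}, std::vector<std::string>{"alpha", "beta"}, nd4j::DataType::UTF8);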
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -250,7 +300,6 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
|
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns new array with the same shape & data type
|
* This method returns new array with the same shape & data type
|
||||||
* @return
|
* @return
|
||||||
|
@ -1148,6 +1197,9 @@ namespace nd4j {
|
||||||
template <typename N>
|
template <typename N>
|
||||||
NDArray asT() const;
|
NDArray asT() const;
|
||||||
|
|
||||||
|
template <typename S>
|
||||||
|
NDArray asS() const;
|
||||||
|
|
||||||
NDArray asT(DataType dtype) const;
|
NDArray asT(DataType dtype) const;
|
||||||
|
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,6 +17,7 @@
|
||||||
|
|
||||||
//
|
//
|
||||||
// Created by raver119 on 2018-09-16.
|
// Created by raver119 on 2018-09-16.
|
||||||
|
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
|
||||||
//
|
//
|
||||||
|
|
||||||
#ifndef DEV_TESTS_NDARRAYFACTORY_H
|
#ifndef DEV_TESTS_NDARRAYFACTORY_H
|
||||||
|
@ -106,25 +108,72 @@ namespace nd4j {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<T>& data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<T>& data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
|
||||||
static NDArray string(const char *string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
/**
|
||||||
|
* This factory creates an array from a UTF-8 string
|
||||||
|
* @return NDArray default dataType UTF8
|
||||||
|
*/
|
||||||
|
static NDArray string(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
static NDArray* string_(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
static NDArray* string_(const std::string &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
static NDArray string(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
static NDArray* string_(const char *string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
/**
|
||||||
|
* This factory creates an array from a UTF-16 string
|
||||||
|
* @return NDArray default dataType UTF16
|
||||||
|
*/
|
||||||
|
static NDArray string(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray string(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
static NDArray string(const std::string &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
/**
|
||||||
|
* This factory creates an array from a UTF-32 string
|
||||||
|
* @return NDArray default dataType UTF32
|
||||||
|
*/
|
||||||
|
static NDArray string(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray string(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
static NDArray* string_(const std::string &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
/**
|
||||||
|
* This factory creates an array from a vector of UTF-8 strings
|
||||||
|
* @return NDArray default dataType UTF8
|
||||||
|
*/
|
||||||
|
static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
|
||||||
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
/**
|
||||||
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
* This factory creates an array from a vector of UTF-16 strings
|
||||||
|
* @return NDArray default dataType UTF16
|
||||||
|
*/
|
||||||
|
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
/**
|
||||||
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
* This factory creates an array from a vector of UTF-32 strings
|
||||||
|
* @return NDArray default dataType UTF32
|
||||||
|
*/
|
||||||
|
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
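// A usage sketch for the new factories (values are hypothetical; dtype and
// context fall back to their defaults):
//   auto a = NDArrayFactory::string(u"text");                                // UTF-16 scalar
//   auto v = NDArrayFactory::string({2}, std::vector<std::string>{"x", "y"}); // UTF-8 array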
||||||
|
|
||||||
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
|
||||||
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
|
||||||
|
|
||||||
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
|
||||||
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
|
||||||
|
|
||||||
static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
|
|
||||||
|
|
|
@ -1518,7 +1518,7 @@ ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPoi
|
||||||
typedef nd4j::ShapeList OpaqueShapeList;
|
typedef nd4j::ShapeList OpaqueShapeList;
|
||||||
|
|
||||||
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
|
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
|
||||||
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
|
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs);
|
||||||
|
|
||||||
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
|
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
|
||||||
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i);
|
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i);
|
||||||
|
@ -1607,6 +1607,7 @@ ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *
|
||||||
ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
||||||
ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
|
ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
|
||||||
ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
|
ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
|
||||||
|
ND4J_EXPORT void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments);
|
||||||
ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
|
ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
|
||||||
ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
|
ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
|
||||||
ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
|
ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,6 +17,7 @@
|
||||||
|
|
||||||
//
|
//
|
||||||
// Created by GS <sgazeos@gmail.com> on 2018-12-20.
|
// Created by GS <sgazeos@gmail.com> on 2018-12-20.
|
||||||
|
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <NDArrayFactory.h>
|
#include <NDArrayFactory.h>
|
||||||
|
@ -25,6 +27,9 @@
|
||||||
#include <ShapeUtils.h>
|
#include <ShapeUtils.h>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
|
|
||||||
|
#include <StringUtils.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -85,45 +90,6 @@ namespace nd4j {
|
||||||
template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t>& data, nd4j::LaunchContext * context);
|
template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t>& data, nd4j::LaunchContext * context);
|
||||||
template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool>& data, nd4j::LaunchContext * context);
|
template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool>& data, nd4j::LaunchContext * context);
|
||||||
|
|
||||||
NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) {
|
|
||||||
std::string s(str);
|
|
||||||
return string(s, context);
|
|
||||||
}
|
|
||||||
|
|
||||||
NDArray* NDArrayFactory::string_(const char *str, nd4j::LaunchContext * context) {
|
|
||||||
return string_(std::string(str), context);
|
|
||||||
}
|
|
||||||
|
|
||||||
NDArray NDArrayFactory::string(const std::string &str, nd4j::LaunchContext * context) {
|
|
||||||
|
|
||||||
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1);
|
|
||||||
|
|
||||||
std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(headerLength + str.length(), DataType::UTF8, context->getWorkspace(), true);
|
|
||||||
|
|
||||||
NDArray res(pBuffer, ShapeDescriptor::scalarDescriptor(DataType::UTF8), context);
|
|
||||||
|
|
||||||
int8_t* buffer = reinterpret_cast<int8_t*>(res.getBuffer());
|
|
||||||
|
|
||||||
auto offsets = reinterpret_cast<Nd4jLong *>(buffer);
|
|
||||||
offsets[0] = 0;
|
|
||||||
offsets[1] = str.length();
|
|
||||||
|
|
||||||
auto data = buffer + headerLength;
|
|
||||||
|
|
||||||
memcpy(data, str.c_str(), str.length());
|
|
||||||
|
|
||||||
res.tickWriteHost();
|
|
||||||
res.syncToDevice();
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
NDArray* NDArrayFactory::string_(const std::string &str, nd4j::LaunchContext * context) {
|
|
||||||
auto res = new NDArray();
|
|
||||||
*res = NDArrayFactory::string(str, context);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext * context) {
|
NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext * context) {
|
||||||
|
@ -551,91 +517,175 @@ template ND4J_EXPORT NDArray NDArrayFactory::create(uint8_t * buffer, const char
|
||||||
template ND4J_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context);
|
template ND4J_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context);
|
||||||
template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context);
|
template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context);
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context) {
|
NDArray NDArrayFactory::string(const char16_t* u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
std::vector<const char*> vec(strings);
|
return NDArray(u16string, dtype, context);
|
||||||
return NDArrayFactory::string(order, shape, vec, context);
|
|
||||||
}
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context) {
|
NDArray* NDArrayFactory::string_(const char16_t* u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
std::vector<std::string> vec(strings.size());
|
return string_(std::u16string(u16string), dtype, context);
|
||||||
int cnt = 0;
|
|
||||||
for (auto s:strings)
|
|
||||||
vec[cnt++] = std::string(s);
|
|
||||||
|
|
||||||
return NDArrayFactory::string(order, shape, vec, context);
|
|
||||||
}
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_(const std::u16string& u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context) {
|
|
||||||
std::vector<std::string> vec(string);
|
|
||||||
return NDArrayFactory::string(order, shape, vec, context);
|
|
||||||
}
|
|
||||||
|
|
||||||
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context) {
|
|
||||||
std::vector<const char*> vec(strings);
|
|
||||||
return NDArrayFactory::string_(order, shape, vec, context);
|
|
||||||
}
|
|
||||||
|
|
||||||
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context) {
|
|
||||||
std::vector<std::string> vec(strings.size());
|
|
||||||
int cnt = 0;
|
|
||||||
for (auto s:strings)
|
|
||||||
vec[cnt++] = std::string(s);
|
|
||||||
|
|
||||||
return NDArrayFactory::string_(order, shape, vec, context);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context) {
|
|
||||||
std::vector<std::string> vec(string);
|
|
||||||
return NDArrayFactory::string_(order, shape, vec, context);
|
|
||||||
}
|
|
||||||
|
|
||||||
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context) {
|
|
||||||
|
|
||||||
if (context == nullptr)
|
|
||||||
context = nd4j::LaunchContext ::defaultContext();
|
|
||||||
|
|
||||||
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size());
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> offsets(string.size() + 1);
|
|
||||||
Nd4jLong dataLength = 0;
|
|
||||||
for (int e = 0; e < string.size(); e++) {
|
|
||||||
offsets[e] = dataLength;
|
|
||||||
dataLength += string[e].length();
|
|
||||||
}
|
|
||||||
offsets[string.size()] = dataLength;
|
|
||||||
|
|
||||||
std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(headerLength + dataLength, DataType::UTF8, context->getWorkspace(), true);
|
|
||||||
|
|
||||||
NDArray res(pBuffer, ShapeDescriptor(DataType::UTF8, order, shape), context);
|
|
||||||
res.setAttached(context->getWorkspace() != nullptr);
|
|
||||||
|
|
||||||
if (res.lengthOf() != string.size())
|
|
||||||
throw std::invalid_argument("Number of strings should match length of array");
|
|
||||||
|
|
||||||
memcpy(res.buffer(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
|
||||||
|
|
||||||
auto data = static_cast<int8_t*>(res.buffer()) + headerLength;
|
|
||||||
int resLen = res.lengthOf();
|
|
||||||
for (int e = 0; e < resLen; e++) {
|
|
||||||
auto length = offsets[e+1] - offsets[e];
|
|
||||||
auto cdata = data + offsets[e];
|
|
||||||
memcpy(cdata, string[e].c_str(), string[e].length());
|
|
||||||
}
|
|
||||||
|
|
||||||
res.tickWriteHost();
|
|
||||||
res.syncToDevice();
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context) {
|
|
||||||
auto res = new NDArray();
|
auto res = new NDArray();
|
||||||
*res = NDArrayFactory::string(order, shape, string, context);
|
*res = NDArray(u16string, dtype, context);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string(const std::u16string& u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray(u16string, dtype, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string(const char32_t* u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray(u32string, dtype, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_(const char32_t* u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
|
return string_(std::u32string(u32string), dtype, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_(const std::u32string& u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
|
auto res = new NDArray();
|
||||||
|
*res = NDArray(u32string, dtype, context);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string(const std::u32string& u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray(u32string, dtype, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string(const char* str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray(str, dtype, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_(const char* str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
|
return string_(std::string(str), dtype, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_(const std::string& str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
|
auto res = new NDArray();
|
||||||
|
*res = NDArray(str, dtype, context);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string(const std::string& str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray(str, dtype, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string(const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||||
|
return NDArray(shape, std::vector<const char*>(strings), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||||
|
return NDArray( shape, strings, dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||||
|
return NDArray( shape, std::vector<std::string>(string), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||||
|
return NDArrayFactory::string_( shape, std::vector<const char*>(strings), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||||
|
std::vector<std::string> vec(strings.size());
|
||||||
|
int cnt = 0;
|
||||||
|
for (auto s:strings)
|
||||||
|
vec[cnt++] = std::string(s);
|
||||||
|
|
||||||
|
return NDArrayFactory::string_( shape, vec, dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||||
|
return NDArrayFactory::string_( shape, std::vector<std::string>(string), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||||
|
return NDArray(shape, string, dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_(const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||||
|
auto res = new NDArray();
|
||||||
|
*res = NDArray( shape, string, dataType, context);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string(const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray( shape, std::vector<const char16_t*>(strings), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray( shape, strings, dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray( shape, std::vector<std::u16string>(string), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
return NDArrayFactory::string_( shape, std::vector<const char16_t*>(strings), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
std::vector<std::u16string> vec(strings.size());
|
||||||
|
int cnt = 0;
|
||||||
|
for (auto s : strings)
|
||||||
|
vec[cnt++] = std::u16string(s);
|
||||||
|
|
||||||
|
return NDArrayFactory::string_( shape, vec, dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
return NDArrayFactory::string_( shape, std::vector<std::u16string>(string), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
auto res = new NDArray();
|
||||||
|
*res = NDArray( shape, string, dataType, context);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray( shape, string, dtype, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray( shape, std::vector<const char32_t*>(strings), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray( shape, strings, dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray(shape, std::vector<std::u32string>(string), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
return NDArrayFactory::string_( shape, std::vector<const char32_t*>(strings), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
std::vector<std::u32string> vec(strings.size());
|
||||||
|
int cnt = 0;
|
||||||
|
for (auto s : strings)
|
||||||
|
vec[cnt++] = std::u32string(s);
|
||||||
|
return NDArrayFactory::string_( shape, vec, dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
return NDArrayFactory::string_( shape, std::vector<std::u32string>(string), dataType, context);
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||||
|
auto res = new NDArray();
|
||||||
|
*res = NDArray( shape, string, dataType, context);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
NDArray NDArrayFactory::string(const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||||
|
return NDArray( shape, string, dtype, context);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1974,7 +1974,7 @@ void deleteShapeList(Nd4jPointer shapeList) {
|
||||||
delete list;
|
delete list;
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
|
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||||
nd4j::graph::VariableSpace varSpace;
|
nd4j::graph::VariableSpace varSpace;
|
||||||
Context block(2, &varSpace);
|
Context block(2, &varSpace);
|
||||||
nd4j::ShapeList inShapes;
|
nd4j::ShapeList inShapes;
|
||||||
|
@ -1988,6 +1988,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
|
||||||
for (int e = 0; e < numBArgs; e++)
|
for (int e = 0; e < numBArgs; e++)
|
||||||
block.getBArguments()->push_back(bArgs[e]);
|
block.getBArguments()->push_back(bArgs[e]);
|
||||||
|
|
||||||
|
for (int e = 0; e < numDArgs; e++)
|
||||||
|
block.getDArguments()->push_back((nd4j::DataType) dArgs[e]);
|
||||||
|
|
||||||
for (int e = 0; e < numInputShapes; e++) {
|
for (int e = 0; e < numInputShapes; e++) {
|
||||||
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
|
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
|
||||||
|
|
||||||
|
@ -2015,11 +2018,11 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
|
||||||
return shapeList;
|
return shapeList;
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
|
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||||
try {
|
try {
|
||||||
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
|
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||||
|
|
||||||
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs);
|
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||||
|
@ -2130,7 +2133,7 @@ Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4
|
||||||
biArgs[e] = bArgs[e];
|
biArgs[e] = bArgs[e];
|
||||||
|
|
||||||
// hypothetically at this point we have everything filled
|
// hypothetically at this point we have everything filled
|
||||||
auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, isInplace);
|
auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, std::vector<nd4j::DataType>(), isInplace);
|
||||||
//auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace);
|
//auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace);
|
||||||
|
|
||||||
|
|
||||||
|
@ -2788,6 +2791,15 @@ void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, i
|
||||||
void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
|
void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
|
||||||
ptr->setBArguments(arguments, numberOfArguments);
|
ptr->setBArguments(arguments, numberOfArguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) {
|
||||||
|
std::vector<nd4j::DataType> dtypes(numberOfArguments);
|
||||||
|
for (int e = 0; e < numberOfArguments; e++)
|
||||||
|
dtypes[e] = (nd4j::DataType) arguments[e];
|
||||||
|
|
||||||
|
ptr->setDArguments(dtypes);
|
||||||
|
}
|
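// DArgs carry data types across the C API as plain ints that are cast back to
// nd4j::DataType; a caller sketch (the context pointer is hypothetical):
//   int dtypes[] = { (int) nd4j::DataType::FLOAT32 };
//   setGraphContextDArguments(ctx, dtypes, 1);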
||||||
|
|
||||||
void deleteGraphContext(nd4j::graph::Context* ptr) {
|
void deleteGraphContext(nd4j::graph::Context* ptr) {
|
||||||
delete ptr;
|
delete ptr;
|
||||||
}
|
}
|
||||||
|
|
|
@@ -2684,7 +2684,7 @@ const char* getAllCustomOps() {
 }

-nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
+nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
     nd4j::graph::VariableSpace varSpace;
     Context block(2, &varSpace);
     nd4j::ShapeList inShapes;

@@ -2698,6 +2698,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
     for (int e = 0; e < numBArgs; e++)
         block.getBArguments()->push_back(bArgs[e]);

+    for (int e = 0; e < numDArgs; e++)
+        block.getDArguments()->push_back((nd4j::DataType) dArgs[e]);
+
     for (int e = 0; e < numInputShapes; e++) {
         auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);

@@ -2722,12 +2725,12 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
     return shapeList;
 }

-nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) {
+nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
     try {
         auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);

         return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs,
-                iArgs, numIArgs, bArgs, numBArgs);
+                iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
     } catch (std::exception &e) {
         nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
         nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());

@@ -2831,7 +2834,7 @@ static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer*

     // hypothetically at this point we have everything filled
-    auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, isInplace);
+    auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, std::vector<nd4j::DataType>(), isInplace);
     //auto dZ = op->execute(inputs, ttArgs, iiArgs, isInplace);
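The realExec change above is the call-site consequence of DeclarableOp::execute growing an explicit D-argument parameter: callers that have no data-type arguments pass an empty vector. A sketch of a caller that does supply one, using the same signature shown in the hunk (variable names illustrative):

    std::vector<nd4j::DataType> dtypes = { nd4j::DataType::INT32 };
    auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, dtypes, isInplace);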
@@ -3596,6 +3599,14 @@ void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int n
     ptr->setBArguments(arguments, numberOfArguments);
 }

+void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) {
+    std::vector<nd4j::DataType> dtypes(numberOfArguments);
+    for (int e = 0; e < numberOfArguments; e++)
+        dtypes[e] = (nd4j::DataType) arguments[e];
+
+    ptr->setDArguments(dtypes);
+}
+
 void deleteGraphContext(nd4j::graph::Context* ptr) {
     delete ptr;
 }
@@ -55,6 +55,7 @@ TESTS="false"
 VERBOSE="false"
 VERBOSE_ARG="VERBOSE=1"
 HELPER=
+CHECK_VECTORIZATION="OFF"
 NAME=
 while [[ $# > 0 ]]
 do

@@ -114,6 +115,9 @@ case $key in
     NAME="$value"
     shift # past argument
     ;;
+    --check-vectorization)
+    CHECK_VECTORIZATION="ON"
+    ;;
     -j)
     MAKEJ="$value"
     shift # past argument

@@ -528,14 +532,27 @@ echo MINIFIER = "${MINIFIER_ARG}"
 echo TESTS = "${TESTS_ARG}"
 echo NAME = "${NAME_ARG}"
 echo OPENBLAS_PATH = "$OPENBLAS_PATH"
+echo CHECK_VECTORIZATION = "$CHECK_VECTORIZATION"
 echo HELPERS = "$HELPERS"
 mkbuilddir
 pwd
-eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" $HELPERS "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../..
+eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DCHECK_VECTORIZATION="${CHECK_VECTORIZATION}" $HELPERS "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../..

 if [ "$PARALLEL" == "true" ]; then
     MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ"
 fi
 if [ "$VERBOSE" == "true" ]; then
     MAKE_ARGUMENTS="$MAKE_ARGUMENTS $VERBOSE_ARG"
 fi

+if [ "$CHECK_VECTORIZATION" == "ON" ]; then
+
+    if [ "$MAKE_COMMAND" == "make" ]; then
+        MAKE_ARGUMENTS="$MAKE_ARGUMENTS --output-sync=target"
+    fi
+    exec 3>&1
+    eval $MAKE_COMMAND $MAKE_ARGUMENTS 2>&1 >&3 3>&- | python3 ../../auto_vectorization/auto_vect.py && cd ../../..
+    exec 3>&-
+else
     eval $MAKE_COMMAND $MAKE_ARGUMENTS && cd ../../..
+fi
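A note on the vectorization path added above: with CHECK_VECTORIZATION=ON the script forces make's per-target output sync so compiler remarks stay grouped, saves the real stdout on file descriptor 3, and pipes stderr (where GCC emits vectorization remarks) through auto_vectorization/auto_vect.py, which turns the remarks into a report while normal build output still reaches the console. Passing --check-vectorization to the build script is all a user needs to opt in.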
@@ -95,6 +95,10 @@ namespace nd4j {
     template<typename T>
     // struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value || std::is_same<long long, T>::value; };
     struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<unsigned int, T>::value || std::is_same<long long, T>::value || std::is_same<unsigned long long, T>::value || std::is_same<long int, T>::value || std::is_same<long unsigned int, T>::value || std::is_same<int8_t, T>::value || std::is_same<uint8_t, T>::value || std::is_same<int16_t, T>::value || std::is_same<uint16_t, T>::value || std::is_same<bool, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value; };

+    template<typename T>
+    struct scalarTypesForExecution { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<Nd4jLong, T>::value || std::is_same<int, T>::value || std::is_same<bool, T>::value; };
+
 };
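The added scalarTypesForExecution trait whitelists the five scalar types ops may be templated over at execution time, a narrower set than scalarTypesForNDarray above it. A standalone sketch of what the trait enables; the trait body is copied from this hunk (with Nd4jLong spelled as long long to stay self-contained) and applyScalar is illustrative, not from this change:

    #include <type_traits>

    template<typename T>
    struct scalarTypesForExecution { static bool const value =
        std::is_same<double, T>::value || std::is_same<float, T>::value ||
        std::is_same<long long, T>::value || std::is_same<int, T>::value ||
        std::is_same<bool, T>::value; };

    // Compiles only for the whitelisted execution types.
    template<typename T, typename std::enable_if<scalarTypesForExecution<T>::value, int>::type = 0>
    void applyScalar(T value) { /* ... dispatch on T ... */ }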
@@ -118,7 +122,7 @@ namespace nd4j {
 }

 FORCEINLINE bool DataTypeUtils::isS(nd4j::DataType dataType) {
-    return dataType == nd4j::DataType::UTF8;
+    return dataType == nd4j::DataType::UTF8 || dataType == nd4j::DataType::UTF16 || dataType == nd4j::DataType::UTF32;
 }

 FORCEINLINE bool DataTypeUtils::isZ(nd4j::DataType dataType) {
@@ -366,6 +370,10 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
         return std::string("UINT64");
     case UTF8:
         return std::string("UTF8");
+    case UTF16:
+        return std::string("UTF16");
+    case UTF32:
+        return std::string("UTF32");
     default:
         throw std::runtime_error("Unknown data type used");
     }
@@ -427,6 +435,8 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
     case nd4j::DataType::UINT16: return (size_t) 2;

     case nd4j::DataType::UTF8:
+    case nd4j::DataType::UTF16:
+    case nd4j::DataType::UTF32:
     case nd4j::DataType::INT32:
     case nd4j::DataType::UINT32:
     case nd4j::DataType::HALF2:
@@ -451,6 +461,10 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
         return nd4j::DataType::BOOL;
     } else if (std::is_same<T, std::string>::value) {
         return nd4j::DataType::UTF8;
+    } else if (std::is_same<T, std::u16string>::value) {
+        return nd4j::DataType::UTF16;
+    } else if (std::is_same<T, std::u32string>::value) {
+        return nd4j::DataType::UTF32;
     } else if (std::is_same<T, float>::value) {
         return nd4j::DataType::FLOAT32;
     } else if (std::is_same<T, float16>::value) {
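These branches extend the compile-time type-to-dtype mapping so the two new string element types resolve to their own codes. A hedged sketch of the effect, assuming the branches belong to DataTypeUtils::fromT<T>() as used elsewhere in this header:

    auto t8  = nd4j::DataTypeUtils::fromT<std::string>();     // nd4j::DataType::UTF8
    auto t16 = nd4j::DataTypeUtils::fromT<std::u16string>();  // nd4j::DataType::UTF16
    auto t32 = nd4j::DataTypeUtils::fromT<std::u32string>();  // nd4j::DataType::UTF32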
@@ -158,7 +158,7 @@ namespace nd4j {
     iargs.push_back(_axis);

-    auto result = op.execute(inputs, {}, {}, {});
+    auto result = op.evaluate(inputs);

     auto array = new NDArray(result->at(0)->dup());
@@ -197,10 +197,12 @@ namespace nd4j {
     void setTArguments(double *arguments, int numberOfArguments);
     void setIArguments(Nd4jLong *arguments, int numberOfArguments);
     void setBArguments(bool *arguments, int numberOfArguments);
+    void setDArguments(nd4j::DataType *arguments, int numberOfArguments);

     void setTArguments(const std::vector<double> &tArgs);
     void setIArguments(const std::vector<Nd4jLong> &tArgs);
     void setBArguments(const std::vector<bool> &tArgs);
+    void setDArguments(const std::vector<nd4j::DataType> &dArgs);

     void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer);
@@ -47,6 +47,9 @@ namespace nd4j {
     std::vector<int> _iArgs;
     std::vector<bool> _bArgs;
     std::vector<int> _axis;
+    std::vector<nd4j::DataType> _dArgs;

+    // TODO: remove this field
     nd4j::DataType _dataType = nd4j::DataType::FLOAT32;
     bool _isInplace;
@@ -93,6 +96,7 @@ namespace nd4j {
     std::vector<double>* getTArguments();
     std::vector<int>* getIArguments();
     std::vector<bool>* getBArguments();
+    std::vector<nd4j::DataType>* getDArguments();
     std::vector<int>* getAxis();

     samediff::Engine engine();

@@ -100,6 +104,7 @@ namespace nd4j {
     size_t numT();
     size_t numI();
     size_t numB();
+    size_t numD();

     std::pair<int, int>* input(int idx);
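With the accessor and the counter declared above, op code can read D-arguments the same way it already reads T/I/B-arguments. A minimal sketch, where block is the incoming context and the FLOAT32 fallback is an assumption of this note, not of the change:

    nd4j::DataType outType = block.numD() > 0
            ? block.getDArguments()->at(0)
            : nd4j::DataType::FLOAT32;   // assumed default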
@@ -38,7 +38,9 @@ namespace nd4j {

 class ND4J_EXPORT Node {
 protected:
+    // TODO: this field must be removed
     nd4j::DataType _dataType;

     OpType _opType;
     ContextPrototype* _protoContext = nullptr;
     Nd4jLong _opNum;

@@ -61,6 +63,7 @@ namespace nd4j {

     // optional scalar. used in scalar ops and in summary stats
+    // TODO: this field must be removed
     NDArray _scalar;

     bool _hasExternalOutputs;
@@ -87,15 +90,15 @@ namespace nd4j {
     int _scope_id = 0;
     std::string _scope_name;

+    // TODO: these 3 fields should be removed
     int _rewindNode = -1;
     std::pair<int, int> _rewindLayer = {-1, -1};
     Nd4jLong _frameId = -1;

 public:
-    Node(nd4j::ops::DeclarableOp *customOp, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
+    explicit Node(nd4j::ops::DeclarableOp *customOp, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
-    Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
+    explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
-    Node(const nd4j::graph::FlatNode *node);
+    explicit Node(const nd4j::graph::FlatNode *node);
     ~Node();

     bool equals(Node *other);
@@ -60,11 +60,13 @@ enum DType {
     DType_QINT16 = 16,
     DType_BFLOAT16 = 17,
     DType_UTF8 = 50,
+    DType_UTF16 = 51,
+    DType_UTF32 = 52,
     DType_MIN = DType_INHERIT,
-    DType_MAX = DType_UTF8
+    DType_MAX = DType_UTF32
 };

-inline const DType (&EnumValuesDType())[19] {
+inline const DType (&EnumValuesDType())[21] {
     static const DType values[] = {
         DType_INHERIT,
         DType_BOOL,

@@ -84,7 +86,9 @@ inline const DType (&EnumValuesDType())[19] {
         DType_QINT8,
         DType_QINT16,
         DType_BFLOAT16,
-        DType_UTF8
+        DType_UTF8,
+        DType_UTF16,
+        DType_UTF32
     };
     return values;
 }

@@ -142,6 +146,8 @@ inline const char * const *EnumNamesDType() {
         "",
         "",
         "UTF8",
+        "UTF16",
+        "UTF32",
         nullptr
     };
     return names;
@@ -42,7 +42,9 @@ nd4j.graph.DType = {
     QINT8: 15,
     QINT16: 16,
     BFLOAT16: 17,
-    UTF8: 50
+    UTF8: 50,
+    UTF16: 51,
+    UTF32: 52
 };

 /**

@@ -26,6 +26,8 @@ public enum DType : sbyte
     QINT16 = 16,
     BFLOAT16 = 17,
     UTF8 = 50,
+    UTF16 = 51,
+    UTF32 = 52,
 };
@@ -23,8 +23,10 @@ public final class DType {
     public static final byte QINT16 = 16;
     public static final byte BFLOAT16 = 17;
     public static final byte UTF8 = 50;
+    public static final byte UTF16 = 51;
+    public static final byte UTF32 = 52;

-    public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", };
+    public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", "UTF16", "UTF32", };

     public static String name(int e) { return names[e]; }
 }

@@ -22,4 +22,6 @@ class DType(object):
     QINT16 = 16
     BFLOAT16 = 17
     UTF8 = 50
+    UTF16 = 51
+    UTF32 = 52
@@ -112,6 +112,14 @@ public struct FlatNode : IFlatbufferObject
   public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } }
   public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
   public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } }
+  public DType ExtraTypes(int j) { int o = __p.__offset(48); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
+  public int ExtraTypesLength { get { int o = __p.__offset(48); return o != 0 ? __p.__vector_len(o) : 0; } }
+#if ENABLE_SPAN_T
+  public Span<byte> GetExtraTypesBytes() { return __p.__vector_as_span(48); }
+#else
+  public ArraySegment<byte>? GetExtraTypesBytes() { return __p.__vector_as_arraysegment(48); }
+#endif
+  public DType[] GetExtraTypesArray() { return __p.__vector_as_array<DType>(48); }

   public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
       int id = 0,

@@ -135,9 +143,11 @@ public struct FlatNode : IFlatbufferObject
       Offset<FlatArray> scalarOffset = default(Offset<FlatArray>),
       VectorOffset controlDepsOffset = default(VectorOffset),
       VectorOffset varControlDepsOffset = default(VectorOffset),
-      VectorOffset controlDepForOffset = default(VectorOffset)) {
-    builder.StartObject(22);
+      VectorOffset controlDepForOffset = default(VectorOffset),
+      VectorOffset extraTypesOffset = default(VectorOffset)) {
+    builder.StartObject(23);
     FlatNode.AddOpNum(builder, opNum);
+    FlatNode.AddExtraTypes(builder, extraTypesOffset);
     FlatNode.AddControlDepFor(builder, controlDepForOffset);
     FlatNode.AddVarControlDeps(builder, varControlDepsOffset);
     FlatNode.AddControlDeps(builder, controlDepsOffset);

@@ -162,7 +172,7 @@ public struct FlatNode : IFlatbufferObject
     return FlatNode.EndFlatNode(builder);
   }

-  public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); }
+  public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(23); }
   public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); }
   public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
   public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); }

@@ -224,6 +234,10 @@ public struct FlatNode : IFlatbufferObject
   public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
   public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
   public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
+  public static void AddExtraTypes(FlatBufferBuilder builder, VectorOffset extraTypesOffset) { builder.AddOffset(22, extraTypesOffset.Value, 0); }
+  public static VectorOffset CreateExtraTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
+  public static VectorOffset CreateExtraTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
+  public static void StartExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
   public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
     int o = builder.EndObject();
     return new Offset<FlatNode>(o);
@@ -72,6 +72,10 @@ public final class FlatNode extends Table {
   public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
   public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
   public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
+  public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; }
+  public int extraTypesLength() { int o = __offset(48); return o != 0 ? __vector_len(o) : 0; }
+  public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); }
+  public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); }

   public static int createFlatNode(FlatBufferBuilder builder,
       int id,

@@ -95,9 +99,11 @@ public final class FlatNode extends Table {
       int scalarOffset,
       int controlDepsOffset,
       int varControlDepsOffset,
-      int controlDepForOffset) {
-    builder.startObject(22);
+      int controlDepForOffset,
+      int extraTypesOffset) {
+    builder.startObject(23);
     FlatNode.addOpNum(builder, opNum);
+    FlatNode.addExtraTypes(builder, extraTypesOffset);
     FlatNode.addControlDepFor(builder, controlDepForOffset);
     FlatNode.addVarControlDeps(builder, varControlDepsOffset);
     FlatNode.addControlDeps(builder, controlDepsOffset);

@@ -122,7 +128,7 @@ public final class FlatNode extends Table {
     return FlatNode.endFlatNode(builder);
   }

-  public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
+  public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); }
   public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
   public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
   public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }

@@ -171,6 +177,9 @@ public final class FlatNode extends Table {
   public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
   public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
   public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); }
+  public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
+  public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
   public static int endFlatNode(FlatBufferBuilder builder) {
     int o = builder.endObject();
     return o;
@@ -339,7 +339,29 @@ class FlatNode(object):
             return self._tab.VectorLen(o)
         return 0

-def FlatNodeStart(builder): builder.StartObject(22)
+    # FlatNode
+    def ExtraTypes(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
+        return 0
+
+    # FlatNode
+    def ExtraTypesAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o)
+        return 0
+
+    # FlatNode
+    def ExtraTypesLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def FlatNodeStart(builder): builder.StartObject(23)
 def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0)
 def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
 def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0)

@@ -375,4 +397,6 @@ def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTR
 def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
 def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0)
 def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def FlatNodeAddExtraTypes(builder, extraTypes): builder.PrependUOffsetTRelativeSlot(22, flatbuffers.number_types.UOffsetTFlags.py_type(extraTypes), 0)
+def FlatNodeStartExtraTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1)
 def FlatNodeEnd(builder): return builder.EndObject()
@@ -26,7 +26,7 @@ public struct UIVariable : IFlatbufferObject
 #endif
   public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
   public VarType Type { get { int o = __p.__offset(8); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } }
-  public DataType Datatype { get { int o = __p.__offset(10); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
+  public DType Datatype { get { int o = __p.__offset(10); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
   public long Shape(int j) { int o = __p.__offset(12); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
   public int ShapeLength { get { int o = __p.__offset(12); return o != 0 ? __p.__vector_len(o) : 0; } }
 #if ENABLE_SPAN_T

@@ -70,7 +70,7 @@ public struct UIVariable : IFlatbufferObject
       Offset<IntPair> idOffset = default(Offset<IntPair>),
       StringOffset nameOffset = default(StringOffset),
       VarType type = VarType.VARIABLE,
-      DataType datatype = DataType.INHERIT,
+      DType datatype = DType.INHERIT,
       VectorOffset shapeOffset = default(VectorOffset),
       VectorOffset controlDepsOffset = default(VectorOffset),
       StringOffset outputOfOpOffset = default(StringOffset),

@@ -101,7 +101,7 @@ public struct UIVariable : IFlatbufferObject
   public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
   public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
   public static void AddType(FlatBufferBuilder builder, VarType type) { builder.AddSbyte(2, (sbyte)type, 0); }
-  public static void AddDatatype(FlatBufferBuilder builder, DataType datatype) { builder.AddSbyte(3, (sbyte)datatype, 0); }
+  public static void AddDatatype(FlatBufferBuilder builder, DType datatype) { builder.AddSbyte(3, (sbyte)datatype, 0); }
   public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(4, shapeOffset.Value, 0); }
   public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
   public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }
@@ -38,7 +38,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     VT_SCALAR = 40,
     VT_CONTROLDEPS = 42,
     VT_VARCONTROLDEPS = 44,
-    VT_CONTROLDEPFOR = 46
+    VT_CONTROLDEPFOR = 46,
+    VT_EXTRATYPES = 48
   };
   int32_t id() const {
     return GetField<int32_t>(VT_ID, 0);

@@ -106,6 +107,9 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor() const {
     return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOR);
   }
+  const flatbuffers::Vector<int8_t> *extraTypes() const {
+    return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_EXTRATYPES);
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyField<int32_t>(verifier, VT_ID) &&

@@ -153,6 +157,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
           VerifyOffset(verifier, VT_CONTROLDEPFOR) &&
           verifier.VerifyVector(controlDepFor()) &&
           verifier.VerifyVectorOfStrings(controlDepFor()) &&
+          VerifyOffset(verifier, VT_EXTRATYPES) &&
+          verifier.VerifyVector(extraTypes()) &&
           verifier.EndTable();
   }
 };

@@ -226,6 +232,9 @@ struct FlatNodeBuilder {
   void add_controlDepFor(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor) {
     fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor);
   }
+  void add_extraTypes(flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes) {
+    fbb_.AddOffset(FlatNode::VT_EXTRATYPES, extraTypes);
+  }
   explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
       : fbb_(_fbb) {
     start_ = fbb_.StartTable();

@@ -261,9 +270,11 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNode(
     flatbuffers::Offset<FlatArray> scalar = 0,
     flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
     flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps = 0,
-    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0) {
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0,
+    flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes = 0) {
   FlatNodeBuilder builder_(_fbb);
   builder_.add_opNum(opNum);
+  builder_.add_extraTypes(extraTypes);
   builder_.add_controlDepFor(controlDepFor);
   builder_.add_varControlDeps(varControlDeps);
   builder_.add_controlDeps(controlDeps);

@@ -311,7 +322,8 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
     flatbuffers::Offset<FlatArray> scalar = 0,
     const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
     const std::vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps = nullptr,
-    const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr) {
+    const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr,
+    const std::vector<int8_t> *extraTypes = nullptr) {
   return nd4j::graph::CreateFlatNode(
       _fbb,
       id,

@@ -335,7 +347,8 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
       scalar,
       controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
       varControlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*varControlDeps) : 0,
-      controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0);
+      controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0,
+      extraTypes ? _fbb.CreateVector<int8_t>(*extraTypes) : 0);
 }

 inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) {
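Putting the generated C++ pieces together, a writer can attach D-args to a serialized node through the new field. A minimal sketch using the builder hook added above; the id value and dtype choice are illustrative, and all other FlatNode fields keep their defaults:

    flatbuffers::FlatBufferBuilder fbb;
    std::vector<int8_t> dtypes = { (int8_t) nd4j::graph::DType_UTF16 };
    auto extra = fbb.CreateVector<int8_t>(dtypes);

    nd4j::graph::FlatNodeBuilder builder(fbb);
    builder.add_id(1);
    builder.add_extraTypes(extra);   // new in this change
    auto node = builder.Finish();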
@@ -398,11 +398,36 @@ nd4j.graph.FlatNode.prototype.controlDepForLength = function() {
   return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
 };

+/**
+ * @param {number} index
+ * @returns {nd4j.graph.DType}
+ */
+nd4j.graph.FlatNode.prototype.extraTypes = function(index) {
+  var offset = this.bb.__offset(this.bb_pos, 48);
+  return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
+};
+
+/**
+ * @returns {number}
+ */
+nd4j.graph.FlatNode.prototype.extraTypesLength = function() {
+  var offset = this.bb.__offset(this.bb_pos, 48);
+  return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
+};
+
+/**
+ * @returns {Int8Array}
+ */
+nd4j.graph.FlatNode.prototype.extraTypesArray = function() {
+  var offset = this.bb.__offset(this.bb_pos, 48);
+  return offset ? new Int8Array(this.bb.bytes().buffer, this.bb.bytes().byteOffset + this.bb.__vector(this.bb_pos + offset), this.bb.__vector_len(this.bb_pos + offset)) : null;
+};
+
 /**
  * @param {flatbuffers.Builder} builder
  */
 nd4j.graph.FlatNode.startFlatNode = function(builder) {
-  builder.startObject(22);
+  builder.startObject(23);
 };

 /**

@@ -854,6 +879,35 @@ nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) {
   builder.startVector(4, numElems, 4);
 };

+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {flatbuffers.Offset} extraTypesOffset
+ */
+nd4j.graph.FlatNode.addExtraTypes = function(builder, extraTypesOffset) {
+  builder.addFieldOffset(22, extraTypesOffset, 0);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {Array.<nd4j.graph.DType>} data
+ * @returns {flatbuffers.Offset}
+ */
+nd4j.graph.FlatNode.createExtraTypesVector = function(builder, data) {
+  builder.startVector(1, data.length, 1);
+  for (var i = data.length - 1; i >= 0; i--) {
+    builder.addInt8(data[i]);
+  }
+  return builder.endVector();
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {number} numElems
+ */
+nd4j.graph.FlatNode.startExtraTypesVector = function(builder, numElems) {
+  builder.startVector(1, numElems, 1);
+};
+
 /**
  * @param {flatbuffers.Builder} builder
  * @returns {flatbuffers.Offset}
@@ -266,8 +266,8 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   VarType type() const {
     return static_cast<VarType>(GetField<int8_t>(VT_TYPE, 0));
   }
-  DataType datatype() const {
-    return static_cast<DataType>(GetField<int8_t>(VT_DATATYPE, 0));
+  DType datatype() const {
+    return static_cast<DType>(GetField<int8_t>(VT_DATATYPE, 0));
   }
   const flatbuffers::Vector<int64_t> *shape() const {
     return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);

@@ -342,7 +342,7 @@ struct UIVariableBuilder {
   void add_type(VarType type) {
     fbb_.AddElement<int8_t>(UIVariable::VT_TYPE, static_cast<int8_t>(type), 0);
   }
-  void add_datatype(DataType datatype) {
+  void add_datatype(DType datatype) {
     fbb_.AddElement<int8_t>(UIVariable::VT_DATATYPE, static_cast<int8_t>(datatype), 0);
   }
   void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {

@@ -389,7 +389,7 @@ inline flatbuffers::Offset<UIVariable> CreateUIVariable(
     flatbuffers::Offset<IntPair> id = 0,
     flatbuffers::Offset<flatbuffers::String> name = 0,
     VarType type = VarType_VARIABLE,
-    DataType datatype = DataType_INHERIT,
+    DType datatype = DType_INHERIT,
     flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
     flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
     flatbuffers::Offset<flatbuffers::String> outputOfOp = 0,

@@ -421,7 +421,7 @@ inline flatbuffers::Offset<UIVariable> CreateUIVariableDirect(
     flatbuffers::Offset<IntPair> id = 0,
     const char *name = nullptr,
     VarType type = VarType_VARIABLE,
-    DataType datatype = DataType_INHERIT,
+    DType datatype = DType_INHERIT,
     const std::vector<int64_t> *shape = nullptr,
     const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
     const char *outputOfOp = nullptr,
@@ -503,11 +503,11 @@ nd4j.graph.UIVariable.prototype.type = function() {
 };

 /**
- * @returns {nd4j.graph.DataType}
+ * @returns {nd4j.graph.DType}
  */
 nd4j.graph.UIVariable.prototype.datatype = function() {
   var offset = this.bb.__offset(this.bb_pos, 10);
-  return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
+  return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
 };

 /**

@@ -668,10 +668,10 @@ nd4j.graph.UIVariable.addType = function(builder, type) {

 /**
  * @param {flatbuffers.Builder} builder
- * @param {nd4j.graph.DataType} datatype
+ * @param {nd4j.graph.DType} datatype
  */
 nd4j.graph.UIVariable.addDatatype = function(builder, datatype) {
-  builder.addFieldInt8(3, datatype, nd4j.graph.DataType.INHERIT);
+  builder.addFieldInt8(3, datatype, nd4j.graph.DType.INHERIT);
 };

 /**
@@ -551,6 +551,18 @@ namespace nd4j {
     bool Context::isInference() {
         return _execMode == samediff::ExecutionMode::MODE_INFERENCE;
     }

+    void Context::setDArguments(nd4j::DataType *arguments, int numberOfArguments) {
+        _dArgs.clear();
+        for (int e = 0; e < numberOfArguments; e++)
+            _dArgs.emplace_back(arguments[e]);
+    }
+
+    void Context::setDArguments(const std::vector<nd4j::DataType> &dArgs) {
+        _dArgs.clear();
+        for (auto d:dArgs)
+            _dArgs.emplace_back(d);
+    }
 }
 }
@@ -173,5 +173,13 @@ namespace nd4j {

         return clone;
     }

+    std::vector<nd4j::DataType> *ContextPrototype::getDArguments() {
+        return &_dArgs;
+    }
+
+    size_t ContextPrototype::numD() {
+        return _dArgs.size();
+    }
 }
 }
@@ -49,11 +49,10 @@ namespace nd4j {
         delete[] newShape;
         return NDArrayFactory::empty_(dtype, nullptr);
     }

+    // TODO fix UTF16 and UTF32
     if (dtype == UTF8) {
         bool isBe = BitwiseUtils::isBE();
         bool canKeep = (isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_BE) || (!isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_LE);
-        auto order = shape::order(newShape);

         std::vector<std::string> substrings(length);
         std::vector<Nd4jLong> shapeVector(rank);

@@ -88,8 +87,8 @@ namespace nd4j {

         delete[] offsets;
         delete[] newShape;
-        return NDArrayFactory::string_(order, shapeVector, substrings);
+        // string order always 'c'
+        return NDArrayFactory::string_(shapeVector, substrings);
     }
@@ -587,6 +587,12 @@ namespace nd4j {
         block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
     }

+    if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
+        for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
+            block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
+        }
+    }
+
     this->setContextPrototype(block);
     this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));
     block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());

@@ -618,6 +624,12 @@ namespace nd4j {
         block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
     }

+    if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
+        for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
+            block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
+        }
+    }
+
     this->setContextPrototype(block);

     this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));

@@ -652,6 +664,12 @@ namespace nd4j {
         block->getBArguments()->push_back(node->extraBools()->Get(e));
     }

+    if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
+        for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
+            block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
+        }
+    }
+
     for (auto v: _dimensions)
         block->getAxis()->emplace_back(v);
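Taken together with the serialization side, this closes the round trip: a node's D-args are written into FlatNode.extraTypes, and each of the three deserialization paths above casts the stored bytes back to nd4j::DataType and appends them to the context's D-argument list, alongside the existing T/I/B handling.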
@@ -58,6 +58,8 @@ table FlatNode {
     varControlDeps:[string];
     controlDepFor:[string];

+    // DArgs
+    extraTypes:[DType];
 }

 root_type FlatNode;
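Because FlatBuffers fields are addressed by slot id, appending extraTypes at the end of the table keeps the schema backward compatible: buffers written before this change simply report an absent vector (extraTypes() returns null and lengths read as 0), which is exactly what the null checks in the deserialization code above rely on.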
@ -171,7 +171,10 @@ namespace nd4j {
      * @param numStrings
      * @return
      */
-    static Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings);
+    static FORCEINLINE Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings) {
+        // we store +1 offset
+        return (numStrings + 1) * sizeof(Nd4jLong);
+    }

     /*
      * check whether arr1/arr2 is sub-array of arr2/arr1,
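The inlined body makes the string-buffer layout explicit: a string NDArray stores one Nd4jLong offset per string plus one trailing entry holding the total byte count, hence (numStrings + 1) entries. A standalone re-statement of the arithmetic, hedged and assuming an 8-byte Nd4jLong:

    #include <cassert>
    #include <cstdint>

    using Nd4jLong = int64_t; // stand-in for the libnd4j typedef

    // mirrors the inlined helper above: one offset per string + one total-bytes entry
    static Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings) {
        return (numStrings + 1) * sizeof(Nd4jLong);
    }

    int main() {
        // header for 3 strings: [off0, off1, off2, totalBytes] -> 4 * 8 = 32 bytes
        assert(stringBufferHeaderRequirements(3) == 32);
        return 0;
    }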
@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019-2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@ -16,6 +17,7 @@
 //
 // Created by raver119 on 20/04/18.
+// @author Oleg Semeniv <oleg.semeniv@gmail.com>
 //

 #ifndef LIBND4J_STRINGUTILS_H

@ -27,6 +29,7 @@
 #include <sstream>
 #include <vector>
 #include <NDArray.h>
+#include <unicode.h>

 namespace nd4j {
     class ND4J_EXPORT StringUtils {
@ -85,6 +88,55 @@ namespace nd4j {
      * @return
      */
     static std::vector<std::string> split(const std::string &haystack, const std::string &delimiter);

+    /**
+     * This method converts a u8 string to u16
+     * @param const reference to the input string
+     * @param reference to the output u16string
+     * @return boolean status
+     */
+    static bool u8StringToU16String(const std::string& u8, std::u16string& u16);
+
+    /**
+     * This method converts a u8 string to u32
+     * @param const reference to the input string
+     * @param reference to the output u32string
+     * @return boolean status
+     */
+    static bool u8StringToU32String(const std::string& u8, std::u32string& u32);
+
+    /**
+     * This method converts a u16 string to u32
+     * @param const reference to the input u16string
+     * @param reference to the output u32string
+     * @return boolean status
+     */
+    static bool u16StringToU32String(const std::u16string& u16, std::u32string& u32);
+
+    /**
+     * This method converts a u16 string to a u8 string
+     * @param const reference to the input u16string
+     * @param reference to the output string
+     * @return boolean status
+     */
+    static bool u16StringToU8String(const std::u16string& u16, std::string& u8);
+
+    /**
+     * This method converts a u32 string to a u16 string
+     * @param const reference to the input u32string
+     * @param reference to the output u16string
+     * @return boolean status
+     */
+    static bool u32StringToU16String(const std::u32string& u32, std::u16string& u16);
+
+    /**
+     * This method converts a u32 string to a u8 string
+     * @param const reference to the input u32string
+     * @param reference to the output string
+     * @return boolean status
+     */
+    static bool u32StringToU8String(const std::u32string& u32, std::string& u8);
 };
 }
@ -40,7 +40,7 @@ namespace nd4j {
     NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); //[nHeads*hS, batch*timeSteps]
     nd4j::ops::matmul mmul;
-    mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {});
+    mmul.execute({&projectionPrep, &inputPrep}, {&projected});

     projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
     projected.permutei({2, 0, 1, 3}); //[minibatch, numHeads, projectedSize, seqLength]

@ -66,7 +66,7 @@ namespace nd4j {
     nd4j::ops::matmul_bp mmulBp;
     NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context);
     NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context);
-    mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});
+    mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, std::vector<NDArray*>{&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});

     dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
     dLdProjectionMatrix->assign(dLdProjectionPrep);
@ -1019,15 +1019,6 @@ std::vector<int> ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const
     return numOfMinTads == 1 ? maxTadDims : std::vector<int>();
 }

-
-Nd4jLong ShapeUtils::stringBufferHeaderRequirements(Nd4jLong numStrings) {
-    // we store +1 offset
-    auto base = numStrings + 1;
-
-    // since we return number of bytes...
-    return base * sizeof(Nd4jLong);
-}

 ////////////////////////////////////////////////////////////////////////////////
 /*
 bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {
@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019-2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@ -16,6 +17,7 @@
 //
 // Created by raver119 on 20/04/18.
+// @author Oleg Semeniv <oleg.semeniv@gmail.com>
 //

 #include <helpers/StringUtils.h>
@ -49,13 +51,8 @@ namespace nd4j {
         if (!array.isS())
             throw nd4j::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType());

-        uint64_t result = 0;
-
-        // our buffer stores offsets, and the last value is basically number of bytes used
         auto buffer = array.bufferAsT<Nd4jLong>();
-        result = buffer[array.lengthOf()];
-
-        return result;
+        return buffer[array.lengthOf()];
     }

     std::vector<std::string> StringUtils::split(const std::string &haystack, const std::string &delimiter) {
@ -73,4 +70,89 @@ namespace nd4j {

         return output;
     }

+    bool StringUtils::u8StringToU16String(const std::string& u8, std::u16string& u16) {
+
+        if (u8.empty())
+            return false;
+
+        u16.resize(unicode::offsetUtf8StringInUtf16(u8.data(), u8.size()) / sizeof(char16_t));
+        if (u8.size() == u16.size())
+            u16.assign(u8.begin(), u8.end());
+        else
+            return unicode::utf8to16(u8.data(), &u16[0], u8.size());
+
+        return true;
+    }
+
+    bool StringUtils::u8StringToU32String(const std::string& u8, std::u32string& u32) {
+
+        if (u8.empty())
+            return false;
+
+        u32.resize(unicode::offsetUtf8StringInUtf32(u8.data(), u8.size()) / sizeof(char32_t));
+        if (u8.size() == u32.size())
+            u32.assign(u8.begin(), u8.end());
+        else
+            return unicode::utf8to32(u8.data(), &u32[0], u8.size());
+
+        return true;
+    }
+
+    bool StringUtils::u16StringToU32String(const std::u16string& u16, std::u32string& u32) {
+
+        if (u16.empty())
+            return false;
+
+        u32.resize(unicode::offsetUtf16StringInUtf32(u16.data(), u16.size()) / sizeof(char32_t));
+        if (u16.size() == u32.size())
+            u32.assign(u16.begin(), u16.end());
+        else
+            return unicode::utf16to32(u16.data(), &u32[0], u16.size());
+
+        return true;
+    }
+
+    bool StringUtils::u16StringToU8String(const std::u16string& u16, std::string& u8) {
+
+        if (u16.empty())
+            return false;
+
+        u8.resize(unicode::offsetUtf16StringInUtf8(u16.data(), u16.size()));
+        if (u16.size() == u8.size())
+            u8.assign(u16.begin(), u16.end());
+        else
+            return unicode::utf16to8(u16.data(), &u8[0], u16.size());
+
+        return true;
+    }
+
+    bool StringUtils::u32StringToU16String(const std::u32string& u32, std::u16string& u16) {
+
+        if (u32.empty())
+            return false;
+
+        u16.resize(unicode::offsetUtf32StringInUtf16(u32.data(), u32.size()) / sizeof(char16_t));
+        if (u32.size() == u16.size())
+            u16.assign(u32.begin(), u32.end());
+        else
+            return unicode::utf32to16(u32.data(), &u16[0], u32.size());
+
+        return true;
+    }
+
+    bool StringUtils::u32StringToU8String(const std::u32string& u32, std::string& u8) {
+
+        if (u32.empty())
+            return false;
+
+        u8.resize(unicode::offsetUtf32StringInUtf8(u32.data(), u32.size()));
+        if (u32.size() == u8.size())
+            u8.assign(u32.begin(), u32.end());
+        else
+            return unicode::utf32to8(u32.data(), &u8[0], u32.size());
+
+        return true;
+    }
+
 }
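All six converters follow one pattern: size the destination with the matching unicode::offset* helper, take the cheap assign() path when the code-unit counts match (pure ASCII), and otherwise run the real transcoder. A hedged usage sketch of the new API (the wrapper and variable names are hypothetical):

    #include <helpers/StringUtils.h>
    #include <string>

    // hypothetical round trip through the new converters, not part of the commit
    bool roundTrip(const std::string& u8in) {
        std::u16string u16;
        if (!nd4j::StringUtils::u8StringToU16String(u8in, u16))
            return false; // empty input or failed transcoding

        std::string u8out;
        if (!nd4j::StringUtils::u16StringToU8String(u16, u8out))
            return false;

        return u8out == u8in; // holds for any valid UTF-8 input
    }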
@ -0,0 +1,456 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2020 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Oleg Semeniv <oleg.semeniv@gmail.com>
+//
+
+#include <unicode.h>
+
+namespace nd4j {
+namespace unicode {
+
+    constexpr uint32_t ONEBYTEBOUND = 0x00000080;
+    constexpr uint32_t TWOBYTEBOUND = 0x00000800;
+    constexpr uint32_t THREEBYTEBOUND = 0x00010000;
+    constexpr uint16_t HIGHBYTEMIN = 0xd800u;
+    constexpr uint16_t HIGHBYTEMAX = 0xdbffu;
+    constexpr uint16_t TRAILBYTEMIN = 0xdc00u;
+    constexpr uint16_t TRAILBYTEMAX = 0xdfffu;
+    constexpr uint16_t HIGHBYTEOFFSET = HIGHBYTEMIN - (0x10000 >> 10);
+    constexpr uint32_t BYTEOFFSET = 0x10000u - (HIGHBYTEMIN << 10) - TRAILBYTEMIN;
+    // Maximum valid value for a Unicode code point
+    constexpr uint32_t CODEPOINTMAX = 0x0010ffffu;
+
+    template<typename T>
+    FORCEINLINE uint8_t castToU8(const T cp) {
+        return static_cast<uint8_t>(0xff & cp);
+    }
+
+    template<typename T>
+    FORCEINLINE uint16_t castToU16(const T cp) {
+        return static_cast<uint16_t>(0xffff & cp);
+    }
+
+    template<typename T>
+    FORCEINLINE uint32_t castToU32(const T cp) {
+        return static_cast<uint32_t>(0xffffff & cp);
+    }
+
+    template<typename T>
+    FORCEINLINE bool isTrail(const T cp) {
+        return ((castToU8(cp) >> 6) == 0x2);
+    }
+
+    template <typename T>
+    FORCEINLINE bool isHighSurrogate(const T cp) {
+        return (cp & 0xfffffc00) == 0xd800;
+    }
+
+    template <typename T>
+    bool isLowSurrogate(const T cp) {
+        return (cp & 0xfffffc00) == 0xdc00;
+    }
+
+    template <typename T>
+    FORCEINLINE bool isLeadSurrogate(const T cp) {
+        return (cp >= HIGHBYTEMIN && cp <= HIGHBYTEMAX);
+    }
+
+    template <typename T>
+    FORCEINLINE bool isTrailSurrogate(const T cp) {
+        return (cp >= TRAILBYTEMIN && cp <= TRAILBYTEMAX);
+    }
+
+    template <typename T>
+    FORCEINLINE bool isSurrogateU8(const T cp) {
+        return (cp >= HIGHBYTEMIN && cp <= TRAILBYTEMAX);
+    }
+
+    template <typename T>
+    FORCEINLINE bool isSurrogateU16(const T cp) {
+        return ((cp - 0xd800u) < 2048u);
+    }
+
+    template <typename T>
+    FORCEINLINE bool isSymbolU8Valid(const T cp) {
+        return (cp <= CODEPOINTMAX && !isSurrogateU8(cp));
+    }
+
+    template <typename T>
+    FORCEINLINE bool isSymbolValid(const T cp) {
+        return (cp <= CODEPOINTMAX);
+    }
+
+    template <typename T>
+    FORCEINLINE uint32_t surrogateU32(const T& high, const T& low) {
+        return (high << 10) + low - 0x35fdc00;
+    }
+
+    template <typename T>
+    Nd4jLong symbolLength(const T* it) {
+        uint8_t lead = castToU8(*it);
+        if (lead < 0x80)
+            return 1;
+        else if ((lead >> 5) == 0x6)
+            return 2;
+        else if ((lead >> 4) == 0xe)
+            return 3;
+        else if ((lead >> 3) == 0x1e)
+            return 4;
+        else
+            return 0;
+    }
+
+    template <typename T>
+    Nd4jLong symbolLength32(const T* it) {
+        auto lead = castToU32(*it);
+        if (lead < ONEBYTEBOUND)
+            return 1;
+        else if (lead < TWOBYTEBOUND)
+            return 2;
+        else if (lead < THREEBYTEBOUND)
+            return 3;
+        else if (lead <= CODEPOINTMAX)
+            return 4;
+        else
+            return 0;
+    }
+
+    template <typename T>
+    Nd4jLong symbolLength16(const T* it) {
+
+        uint32_t lead = castToU16(*it);
+        if (!isLeadSurrogate(lead)) {
+            if (lead < ONEBYTEBOUND)
+                return 1;
+            else if (lead < TWOBYTEBOUND)
+                return 2;
+            else if (lead < THREEBYTEBOUND)
+                return 3;
+            else
+                return 0;
+        }
+        else {
+            return 4;
+        }
+    }
+
+    Nd4jLong offsetUtf8StringInUtf32(const void* start, const void* end) {
+
+        Nd4jLong count = 0;
+        for (auto it = static_cast<const int8_t*>(start); it != end; it++) {
+            auto length = symbolLength(it);
+            it += (length > 0) ? (length - 1) : 0;
+            count += 1;
+        }
+        return static_cast<Nd4jLong>(count * sizeof(char32_t));
+    }
+
+    Nd4jLong offsetUtf16StringInUtf32(const void* start, const void* end) {
+
+        Nd4jLong count = 0;
+        for (auto it = static_cast<const uint16_t*>(start); it != end;) {
+            auto length = symbolLength16(it);
+            it += (4 == length) ? 2 : 1;
+            count += 1;
+        }
+        return static_cast<Nd4jLong>(count * sizeof(char32_t));
+    }
+
+    Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end) {
+
+        Nd4jLong count = 0;
+        for (auto it = static_cast<const int8_t*>(start); it != end; it++) {
+            auto length = symbolLength(it);
+            auto step = ((length > 0) ? (length - 1) : 0);
+            it += step;
+            count += (4 == length) ? 2 : 1;
+        }
+        return static_cast<Nd4jLong>(count * sizeof(char16_t));
+    }
+
+    Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end) {
+
+        Nd4jLong count = 0;
+        for (auto it = static_cast<const uint16_t*>(start); it != end;) {
+            auto length = symbolLength16(it);
+            it += (4 == length) ? 2 : 1;
+            count += length;
+        }
+        return static_cast<Nd4jLong>(count);
+    }
+
+    Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end) {
+
+        Nd4jLong count = 0;
+        for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
+            auto length = symbolLength32(it);
+            count += (4 == length) ? 2 : 1;
+        }
+        return static_cast<Nd4jLong>(count * sizeof(char16_t));
+    }
+
+    Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end) {
+
+        Nd4jLong count = 0;
+        for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
+            count += symbolLength32(it);
+        }
+        return count;
+    }
+
+    bool isStringValidU8(const void* start, const void* stop) {
+        for (auto it = static_cast<const int8_t*>(start); it != stop; it++) {
+            if (!isSymbolU8Valid(castToU8(*it))) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    bool isStringValidU16(const void* start, const void* stop) {
+        for (auto it = static_cast<const uint16_t*>(start); it != stop; it++) {
+            if (!isSymbolValid(castToU32(*it))) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    bool isStringValidU32(const void* start, const void* stop) {
+        for (auto it = static_cast<const uint32_t*>(start); it != stop; it++) {
+            if (!isSymbolValid(castToU32(*it))) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    void* utf16to8Ptr(const void* start, const void* end, void* res) {
+
+        auto result = static_cast<int8_t*>(res);
+        // result has to be pre-allocated
+        for (auto it = static_cast<const uint16_t*>(start); it != end;) {
+            uint32_t cp = castToU16(*it++);
+            if (!isLeadSurrogate(cp)) {
+                if (cp < 0x80) { // for one byte
+                    *(result++) = static_cast<uint8_t>(cp);
+                }
+                else if (cp < 0x800) { // for two bytes
+                    *(result++) = static_cast<uint8_t>((cp >> 6) | 0xc0);
+                    *(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
+                }
+                else { // for three bytes
+                    *(result++) = static_cast<uint8_t>((cp >> 12) | 0xe0);
+                    *(result++) = static_cast<uint8_t>(((cp >> 6) & 0x3f) | 0x80);
+                    *(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
+                }
+            }
+            else {
+                if (it != end) {
+                    uint32_t trail_surrogate = castToU16(*it++);
+                    if (isTrailSurrogate(trail_surrogate))
+                        cp = (cp << 10) + trail_surrogate + BYTEOFFSET;
+                }
+                // for four bytes
+                *(result++) = static_cast<uint8_t>((cp >> 18) | 0xf0);
+                *(result++) = static_cast<uint8_t>(((cp >> 12) & 0x3f) | 0x80);
+                *(result++) = static_cast<uint8_t>(((cp >> 6) & 0x3f) | 0x80);
+                *(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
+            }
+        }
+        return result;
+    }
+
+    void* utf8to16Ptr(const void* start, const void* end, void* res) {
+
+        auto result = static_cast<uint16_t*>(res);
+        // result has to be pre-allocated
+        for (auto it = static_cast<const int8_t*>(start); it != end;) {
+
+            auto nLength = symbolLength(it);
+            uint32_t cp = castToU8(*it++);
+            if (4 != nLength) {
+                if (2 == nLength) {
+                    cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f);
+                }
+                else if (3 == nLength) {
+                    cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff);
+                    cp += (*it++) & 0x3f;
+                }
+                *(result++) = static_cast<uint16_t>(cp);
+            }
+            else {
+                cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff);
+                cp += (castToU8(*it++) << 6) & 0xfff;
+                cp += (*it++) & 0x3f;
+                // make a surrogate pair
+                *(result++) = static_cast<uint16_t>((cp >> 10) + HIGHBYTEOFFSET);
+                *(result++) = static_cast<uint16_t>((cp & 0x3ff) + TRAILBYTEMIN);
+            }
+        }
+        return result;
+    }
+
+    void* utf32to8Ptr(const void* start, const void* end, void* result) {
+
+        auto res = static_cast<uint8_t*>(result);
+        // result has to be pre-allocated
+        for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
+
+            if (*it < 0x80) // for one byte
+                *(res++) = static_cast<uint8_t>(*it);
+            else if (*it < 0x800) { // for two bytes
+                *(res++) = static_cast<uint8_t>((*it >> 6) | 0xc0);
+                *(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
+            }
+            else if (*it < 0x10000) { // for three bytes
+                *(res++) = static_cast<uint8_t>((*it >> 12) | 0xe0);
+                *(res++) = static_cast<uint8_t>(((*it >> 6) & 0x3f) | 0x80);
+                *(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
+            }
+            else { // for four bytes
+                *(res++) = static_cast<uint8_t>((*it >> 18) | 0xf0);
+                *(res++) = static_cast<uint8_t>(((*it >> 12) & 0x3f) | 0x80);
+                *(res++) = static_cast<uint8_t>(((*it >> 6) & 0x3f) | 0x80);
+                *(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
+            }
+        }
+        return result;
+    }
+
+    void* utf8to32Ptr(const void* start, const void* end, void* res) {
+
+        auto result = static_cast<uint32_t*>(res);
+        // result has to be pre-allocated
+        for (auto it = static_cast<const int8_t*>(start); it != end;) {
+
+            auto nLength = symbolLength(it);
+            uint32_t cp = castToU8(*it++);
+            if (2 == nLength) {
+                cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f);
+            }
+            else if (3 == nLength) {
+                cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff);
+                cp += (*it++) & 0x3f;
+            }
+            else if (4 == nLength) {
+                cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff);
+                cp += (castToU8(*it++) << 6) & 0xfff;
+                cp += (*it++) & 0x3f;
+            }
+            (*result++) = cp;
+        }
+        return result;
+    }
+
+    void* utf16to32Ptr(const void* start, const void* end, void* res) {
+
+        auto result = static_cast<uint32_t*>(res);
+        // result has to be pre-allocated
+        for (auto it = static_cast<const uint16_t*>(start); it != end; it++) {
+
+            uint32_t cpHigh = castToU32(*it);
+            if (!isSurrogateU16(cpHigh)) {
+                *result++ = cpHigh;
+            }
+            else {
+                it++;
+                uint32_t cpLow = castToU32(*it);
+                if (isHighSurrogate(cpHigh) && it != end && isLowSurrogate(cpLow)) {
+                    *result++ = surrogateU32(cpHigh, cpLow);
+                }
+            }
+        }
+        return result;
+    }
+
+    void* utf32to16Ptr(const void* start, const void* end, void* res) {
+
+        auto result = static_cast<uint16_t*>(res);
+        // result has to be pre-allocated
+        for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
+
+            uint32_t cpHigh = castToU32(*it);
+            // TODO: check whether this is needed given pre-validation; if so, find out how to check u16
+            if (cpHigh < 0 || cpHigh > 0x10FFFF || (cpHigh >= 0xD800 && cpHigh <= 0xDFFF)) {
+                // Invalid code point. Replace with sentinel, per Unicode standard:
+                *result++ = u'\uFFFD';
+            }
+            else if (cpHigh < 0x10000UL) { // In the BMP.
+                *result++ = static_cast<char16_t>(cpHigh);
+            }
+            else {
+                *result++ = static_cast<char16_t>(((cpHigh - 0x10000UL) / 0x400U) + 0xD800U);
+                *result++ = static_cast<char16_t>(((cpHigh - 0x10000UL) % 0x400U) + 0xDC00U);
+            }
+        }
+        return result;
+    }
+
+    Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize) {
+        return offsetUtf8StringInUtf32(input, static_cast<const int8_t*>(input) + nInputSize);
+    }
+
+    Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize) {
+        return offsetUtf16StringInUtf32(input, static_cast<const uint16_t*>(input) + nInputSize);
+    }
+
+    Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize) {
+        return offsetUtf8StringInUtf16(input, static_cast<const int8_t*>(input) + nInputSize);
+    }
+
+    Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize) {
+        return offsetUtf16StringInUtf8(input, static_cast<const uint16_t*>(input) + nInputSize);
+    }
+
+    Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize) {
+        return offsetUtf32StringInUtf8(input, static_cast<const uint32_t*>(input) + nInputSize);
+    }
+
+    Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize) {
+        return offsetUtf32StringInUtf16(input, static_cast<const uint32_t*>(input) + nInputSize);
+    }
+
+    bool utf8to16(const void* input, void* output, uint32_t nInputSize) {
+        return utf8to16Ptr(input, static_cast<const int8_t*>(input) + nInputSize, output);
+    }
+
+    bool utf8to32(const void* input, void* output, uint32_t nInputSize) {
+        return utf8to32Ptr(input, static_cast<const int8_t*>(input) + nInputSize, output);
+    }
+
+    bool utf16to32(const void* input, void* output, uint32_t nInputSize) {
+        return utf16to32Ptr(input, static_cast<const uint16_t*>(input) + nInputSize, output);
+    }
+
+    bool utf16to8(const void* input, void* output, uint32_t nInputSize) {
+        return utf16to8Ptr(input, static_cast<const uint16_t*>(input) + nInputSize, output);
+    }
+
+    bool utf32to16(const void* input, void* output, uint32_t nInputSize) {
+        return utf32to16Ptr(input, static_cast<const uint32_t*>(input) + nInputSize, output);
+    }
+
+    bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize) {
+        return utf32to8Ptr(input, static_cast<const uint32_t*>(input) + nInputSize, output);
+    }
+
+}
+}
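The decoders above all hang off symbolLength, which classifies a UTF-8 lead byte by its high bits. A standalone re-statement of that classification, hedged; it mirrors the logic above rather than reusing it:

    #include <cassert>
    #include <cstdint>

    // lead byte 0xxxxxxx -> 1 byte, 110xxxxx -> 2, 1110xxxx -> 3, 11110xxx -> 4
    static int utf8SymbolLength(uint8_t lead) {
        if (lead < 0x80)          return 1;
        if ((lead >> 5) == 0x6)   return 2;
        if ((lead >> 4) == 0xe)   return 3;
        if ((lead >> 3) == 0x1e)  return 4;
        return 0; // trail or invalid byte
    }

    int main() {
        assert(utf8SymbolLength('a')  == 1);  // ASCII
        assert(utf8SymbolLength(0xc3) == 2);  // e.g. lead byte of U+00E9
        assert(utf8SymbolLength(0xe2) == 3);  // e.g. lead byte of U+20AC
        assert(utf8SymbolLength(0xf0) == 4);  // e.g. lead byte of U+1F600
        return 0;
    }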
@ -0,0 +1,189 @@
+/*******************************************************************************
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Oleg Semeniv <oleg.semeniv@gmail.com>
+//
+
+#ifndef LIBND4J_UNICODE_H
+#define LIBND4J_UNICODE_H
+
+#include <NDArray.h>
+
+namespace nd4j {
+namespace unicode {
+
+    /**
+     * This method calculates the u16 offset based on utf8
+     * @param const pointer to the utf8 string start point
+     * @param const pointer to the end of the string
+     * @return offset of utf16
+     */
+    Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end);
+
+    /**
+     * This method calculates the u8 offset based on utf16
+     * @param const pointer to the utf16 string start point
+     * @param const pointer to the end of the string
+     * @return offset of utf8
+     */
+    Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end);
+
+    /**
+     * This method calculates the u16 offset based on utf32
+     * @param const pointer to the utf32 string start point
+     * @param const pointer to the end of the string
+     * @return offset of utf16
+     */
+    Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end);
+
+    /**
+     * This method calculates the u8 offset based on utf32
+     * @param const pointer to the utf32 string start point
+     * @param const pointer to the end of the string
+     * @return offset of utf8
+     */
+    Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end);
+
+    /*
+     * This function checks whether the characters in a u8 string are valid
+     */
+    bool isStringValidU8(const void* start, const void* stop);
+
+    /*
+     * This function checks whether the characters in a u16 string are valid
+     */
+    bool isStringValidU16(const void* start, const void* stop);
+
+    /*
+     * This function checks whether the characters in a u32 string are valid
+     */
+    bool isStringValidU32(const void* start, const void* stop);
+
+    /**
+     * This method counts the offset for a utf8 string in utf32
+     * @param const pointer to the utf8 string start point
+     * @param size of the string
+     * @return offset
+     */
+    Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize);
+
+    /**
+     * This method counts the offset for a utf8 string in utf32
+     * @param const pointer to the utf8 string start point
+     * @param const end pointer of the utf8 string
+     * @return offset
+     */
+    Nd4jLong offsetUtf8StringInUtf32(const void* input, const void* stop);
+
+    /**
+     * This method counts the offset for utf32 based on a utf16 string
+     * @param const pointer to the utf16 string start point
+     * @param size of the string
+     * @return offset
+     */
+    Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize);
+
+    /**
+     * This method calculates the offset of u16 based on utf8
+     * @param const pointer to the utf8 string start point
+     * @param size of the string
+     * @return offset of utf16
+     */
+    Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize);
+
+    /**
+     * This method calculates the offset of u8 based on utf16
+     * @param const pointer to the utf16 string start point
+     * @param size of the string
+     * @return offset of utf8
+     */
+    Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize);
+
+    /**
+     * This method calculates the offset of u8 based on utf32
+     * @param const pointer to the utf32 string start point
+     * @param size of the string
+     * @return offset of utf8
+     */
+    Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize);
+
+    /**
+     * This method calculates the offset of u16 based on utf32
+     * @param const pointer to the utf32 string start point
+     * @param size of the string
+     * @return offset of utf16
+     */
+    Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize);
+
+    /**
+     * This method converts a utf8 string to a utf16 string
+     * @param const pointer to the utf8 string start point
+     * @param pointer to the start of the utf16 output
+     * @param size of the input utf8 string
+     * @return status of the conversion
+     */
+    bool utf8to16(const void* input, void* output, uint32_t nInputSize);
+
+    /**
+     * This method converts a utf8 string to a utf32 string
+     * @param const pointer to the utf8 string start point
+     * @param pointer to the start of the utf32 output
+     * @param size of the input utf8 string
+     * @return status of the conversion
+     */
+    bool utf8to32(const void* input, void* output, uint32_t nInputSize);
+
+    /**
+     * This method converts a utf16 string to a utf32 string
+     * @param const pointer to the utf16 string start point
+     * @param pointer to the start of the utf32 output
+     * @param size of the input utf16 string
+     * @return status of the conversion
+     */
+    bool utf16to32(const void* input, void* output, uint32_t nInputSize);
+
+    /**
+     * This method converts a utf16 string to a utf8 string
+     * @param const pointer to the utf16 string start point
+     * @param pointer to the start of the utf8 output
+     * @param size of the input utf16 string
+     * @return status of the conversion
+     */
+    bool utf16to8(const void* input, void* output, uint32_t nInputSize);
+
+    /**
+     * This method converts a utf32 string to a utf16 string
+     * @param const pointer to the utf32 string start point
+     * @param pointer to the start of the utf16 output
+     * @param size of the input utf32 string
+     * @return status of the conversion
+     */
+    bool utf32to16(const void* input, void* output, uint32_t nInputSize);
+
+    /**
+     * This method converts a utf32 string to a utf8 string
+     * @param const pointer to the utf32 string start point
+     * @param pointer to the start of the utf8 output
+     * @param size of the input utf32 string
+     * @return status of the conversion
+     */
+    bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize);
+}
+}
+
+#endif //LIBND4J_UNICODE_H
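The intended calling convention for this header: the offset* functions report, in bytes, how large the destination must be, and the conversion functions then write into memory the caller has already sized. A hedged sketch of that contract (the wrapper name is hypothetical):

    #include <unicode.h>
    #include <string>

    // hypothetical caller, not part of the commit
    bool toU16Example(const std::string& src, std::u16string& dst) {
        // the offset is returned in bytes, so divide by the code-unit size
        auto bytes = nd4j::unicode::offsetUtf8StringInUtf16(src.data(), (uint32_t) src.size());
        dst.resize(bytes / sizeof(char16_t));
        return nd4j::unicode::utf8to16(src.data(), &dst[0], (uint32_t) src.size());
    }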
@ -29,6 +29,7 @@ using namespace randomOps;
 namespace functions {
     namespace random {
+
         template<typename X>
         template<typename OpClass>
         void RandomFunction<X>::execTransform(Nd4jPointer state,

@ -56,6 +57,19 @@ namespace functions {

            if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {

+               if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::elementWiseStride(yShapeInfo) == 1 &&
+                  shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(yShapeInfo)) {
+
+                   auto func = PRAGMA_THREADS_FOR {
+                       PRAGMA_OMP_SIMD
+                       for (auto i = start; i < stop; i++) {
+                           z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments);
+                       }
+                   };
+                   samediff::Threads::parallel_for(func, 0, length, 1);
+               }
+               else {
                uint xShapeInfoCast[MAX_RANK];
                const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);

@ -69,6 +83,7 @@ namespace functions {

                samediff::Threads::parallel_for(func, 0, length, 1);
            }
+               }
            else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {

                uint xShapeInfoCast[MAX_RANK];

@ -169,6 +184,17 @@ namespace functions {

            if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {

+               if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::order(xShapeInfo) == shape::order(zShapeInfo)) {
+
+                   auto func = PRAGMA_THREADS_FOR {
+                       PRAGMA_OMP_SIMD
+                       for (auto i = start; i < stop; i++) {
+                           z[i] = OpClass::op(x[i], i, length, rng, extraArguments);
+                       }
+                   };
+                   samediff::Threads::parallel_for(func, 0, length, 1);
+               }
+               else {
                auto func = PRAGMA_THREADS_FOR {
                    PRAGMA_OMP_SIMD
                    for (uint64_t i = start; i < stop; i += increment) {

@ -179,6 +205,7 @@ namespace functions {

                samediff::Threads::parallel_for(func, 0, length, 1);
            }
+               }
            else {

                uint zShapeInfoCast[MAX_RANK];

@ -208,6 +235,19 @@ namespace functions {
            auto length = shape::length(zShapeInfo);

            nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);

+           if(shape::elementWiseStride(zShapeInfo) == 1) {
+
+               auto func = PRAGMA_THREADS_FOR {
+                   PRAGMA_OMP_SIMD
+                   for (auto i = start; i < stop; i++) {
+                       z[i] = OpClass::op(i, length, rng, extraArguments);
+                   }
+               };
+
+               samediff::Threads::parallel_for(func, 0, length, 1);
+           }
+           else {
            nd4j::OmpLaunchHelper info(length);

            uint zShapeInfoCast[MAX_RANK];

@ -223,6 +263,7 @@ namespace functions {

                samediff::Threads::parallel_for(func, 0, length, 1);
            }
+           }

        template<typename X>
        void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraArguments) {
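All three execTransform variants gain the same special case: when every buffer has element-wise stride 1 and a matching order, the flat index i is the memory offset, so the body collapses to a linear loop that the PRAGMA_OMP_SIMD hint can vectorize; the pre-existing general path still translates each index through the shape info. A standalone sketch of the difference, hedged and using plain OpenMP rather than the samediff thread helpers:

    #include <cstdint>

    // fast path: legal only when elementWiseStride == 1 for every buffer
    void fillLinear(float* z, const float* x, int64_t length) {
        #pragma omp simd
        for (int64_t i = 0; i < length; i++)
            z[i] = x[i] + 1.0f;
    }

    // general path: one index-to-offset translation per element
    void fillStrided(float* z, const float* x, int64_t length,
                     int64_t xStride, int64_t zStride) {
        for (int64_t i = 0; i < length; i++)
            z[i * zStride] = x[i * xStride] + 1.0f;
    }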
@ -1516,7 +1516,9 @@

 #define INPUT_LIST(INDEX) reinterpret_cast<nd4j::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())

+#define D_ARG(INDEX) block.getDArguments()->at(INDEX)
 #define INT_ARG(INDEX) block.getIArguments()->at(INDEX)
+#define I_ARG(INDEX) INT_ARG(INDEX)
 #define T_ARG(INDEX) block.getTArguments()->at(INDEX)
 #define B_ARG(INDEX) block.getBArguments()->at(INDEX)
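D_ARG completes the argument-accessor family (I_ARG is simply an alias for the existing INT_ARG). A hedged sketch of how an op body reads each kind of argument (the indices and variable names are hypothetical):

    // inside some op implementation operating on `block`, not part of the commit:
    auto requestedType = D_ARG(0);   // nd4j::DataType, from block.getDArguments()
    auto kernelHeight  = I_ARG(0);   // alias for INT_ARG(0)
    auto epsilon       = T_ARG(0);   // double argument
    auto keepDims      = B_ARG(0);   // boolean argument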
@ -36,9 +36,8 @@ namespace nd4j {
        public:
            BooleanOp(const char *name, int numInputs, bool scalar);

-           bool evaluate(std::initializer_list<nd4j::NDArray*> args);
-           bool evaluate(std::vector<nd4j::NDArray*>& args);
-           bool evaluate(nd4j::graph::Context& block);
+           bool verify(const std::vector<nd4j::NDArray*>& args);
+           bool verify(nd4j::graph::Context& block);

            Nd4jStatus execute(Context* block) override;
@ -169,13 +169,22 @@ namespace nd4j {
         */
        virtual Nd4jStatus execute(Context* block);

-       nd4j::ResultSet* execute(std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
-       Nd4jStatus execute(std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
-       Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
-
-       nd4j::ResultSet* execute(const std::vector<NDArray*>& inputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs = std::vector<bool>(), bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
-       Nd4jStatus execute(std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs , std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
-       Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs, std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
+       Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs);
+
+       template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
+       Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<T> tArgs);
+
+       Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
+
+       nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs);
+
+       template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
+       nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
+
+       nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
+
+       Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);

        nd4j::ResultSet* execute(const nd4j::OpArgsHolder& holder, bool isInplace = false);
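The header change splits the old kitchen-sink overloads into two families: execute() now takes caller-owned outputs and returns only a status, while evaluate() allocates its outputs and hands back a ResultSet, so empty t/i/b argument lists no longer need to be spelled out. A hedged before/after sketch, mirroring the crelu call site migrated just below:

    nd4j::ops::crelu op;

    // before: auto tmpResult = op.execute({input}, {}, {}, {});
    auto tmpResult = op.evaluate({input});      // allocates and returns results
    if (tmpResult->status() == ND4J_STATUS_OK) {
        auto out = tmpResult->at(0);            // first output NDArray*
        // ...
    }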
@ -73,7 +73,7 @@ namespace nd4j {

            // at first step we build fwd activation
            nd4j::ops::crelu op;
-           auto tmpResult = op.execute({input}, {}, {}, {});
+           auto tmpResult = op.evaluate({input});
            if (tmpResult->status() != ND4J_STATUS_OK)
                return tmpResult->status();

@ -84,7 +84,7 @@ namespace nd4j {
            helpers::reluDerivative(block.launchContext(), actv, epsilonNext);
            // now we split updated array into 2 chunks along last dimension
            nd4j::ops::concat_bp opc;
-           auto dec = opc.execute({input, input, actv}, {}, {-1}, {});
+           auto dec = opc.evaluate({input, input, actv}, {-1});
            if (dec->status() != ND4J_STATUS_OK)
                return dec->status();

@ -103,7 +103,7 @@ namespace nd4j {
            // if (output->isEmpty())
            Nd4jLong width = condition->rankOf();
            nd4j::ops::Where op;
-           std::unique_ptr<ResultSet> res(op.execute({condition}, {}, {}, {}));
+           std::unique_ptr<ResultSet> res(op.evaluate({condition}));
            REQUIRE_OK(res->status());
            NDArray* whereTrue = res->at(0);
            if (whereTrue->isEmpty())

@ -66,7 +66,7 @@ namespace nd4j {
            auto gradY = OUTPUT_VARIABLE(1);
            gradX->assign(epsNext);
            nd4j::ops::floormod op;
-           std::unique_ptr<ResultSet> tmpResult(op.execute({x, y}, {}, {}, {}));
+           std::unique_ptr<ResultSet> tmpResult(op.evaluate({x, y}));

            if (gradY->rankOf() == gradX->rankOf())
                epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY);

@ -118,7 +118,7 @@ namespace ops {
        DECLARE_TYPES(Pow_bp) {
            getOpDescriptor()
                ->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS })
-               ->setAllowedOutputTypes({ ALL_FLOATS }); // TODO maybe wourth to add ALL_INTS
+               ->setAllowedOutputTypes({ ALL_FLOATS });
        }

    }
@ -81,7 +81,7 @@ namespace nd4j {
    }

    // now that all strings are in a single vector, time to fill
-   auto tmp = NDArrayFactory::string('c', {(Nd4jLong) strings.size()}, strings);
+   auto tmp = NDArrayFactory::string({(Nd4jLong) strings.size()}, strings);
    auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());

    // for CUDA mostly
@ -197,8 +197,7 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
    // ***** calculations ***** //

    // notations:
-   // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output
-   // g = dLdO
+   // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output, g = dLdO
    // stdInv = 1 / (v + eps)^0.5
    // N - batch size (product of spatial dimensions)
@ -222,7 +222,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
    auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]

    nd4j::ops::conv2d_bp conv2dBP;
-   const Nd4jStatus status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
+   auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
    if (status != ND4J_STATUS_OK)
        return status;
@ -91,7 +91,7 @@ namespace ops {
    }

    nd4j::ops::softmax softmax;
-   softmax.execute({weights}, {weights}, {}, {-2}, {}, true);
+   softmax.execute({weights}, std::vector<NDArray*>{weights}, {}, {-2}, {}, {}, true);

    mmul.execute({values, weights}, {output}, {}, {}, {});

@ -189,7 +189,7 @@ namespace ops {

    nd4j::ops::matmul_bp mmul_bp;
    NDArray dLdw(weights.getShapeInfo(), block.workspace());
-   mmul_bp.execute({values, &weights, eps}, {dLdv, &dLdw}, {}, {}, {});
+   mmul_bp.execute({values, &weights, eps}, std::vector<NDArray*>{dLdv, &dLdw}, {}, {}, {});

    NDArray dLds(preSoftmax.shapeInfo(), block.workspace());
    nd4j::ops::softmax_bp softmax_bp;

@ -198,7 +198,7 @@ namespace ops {
    if(normalization)
        dLds /= factor;

-   mmul_bp.execute({keys, queries, &dLds}, {dLdk, dLdq}, {}, {1}, {});
+   mmul_bp.execute({keys, queries, &dLds}, std::vector<NDArray*>{dLdk, dLdq}, {}, {1}, {});

    return Status::OK();
}

@ -239,7 +239,7 @@ namespace ops {
    auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize});
    nd4j::ops::matmul_bp matmulBp;
    NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext());
-   matmulBp.execute({&attnResults, Wo, &epsPostReshape}, {&dLdPreWo, dLdWo}, {}, {}, {});
+   matmulBp.execute({&attnResults, Wo, &epsPostReshape}, std::vector<NDArray*>{&dLdPreWo, dLdWo}, {}, {}, {});

    // dLdAttn
    dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)});
@ -31,31 +31,28 @@ namespace ops {
CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {

auto input = INPUT_VARIABLE(0);
+auto output = OUTPUT_VARIABLE(0);
-REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", input->rankOf());

// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-auto argI = *(block.getIArguments());
-auto output = OUTPUT_VARIABLE(0);

const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
-int pH = INT_ARG(4);
+auto pH = INT_ARG(4);
-int pW = INT_ARG(5);
+auto pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto isSameMode = static_cast<bool>(INT_ARG(8));
const auto extraParam0 = INT_ARG(9);
+const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC

+REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

int oH = 0;
int oW = 0;

-int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC

const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
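A minimal call-site sketch for the argument layout above; the shapes and values are illustrative only (not part of the commit), and the evaluate() overload is the one used elsewhere in this diff:

    // avgpool2d over a 4x4 NCHW input: 2x2 kernel, stride 2, no padding,
    // unit dilation, isSameMode = 0, extraParam0 = 0, INT_ARG(10) = 0 (NCHW)
    auto input = NDArrayFactory::create<float>('c', {1, 1, 4, 4});
    input.linspace(1); // 1..16

    nd4j::ops::avgpool2d op;
    auto result = op.evaluate({&input}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 0, 0});
    // result->at(0) has shape {1, 1, 2, 2} holding {3.5, 5.5, 11.5, 13.5}
    delete result;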
@ -207,7 +204,6 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
}

return Status::OK();

}

DECLARE_SHAPE_FN(avgpool2d_bp) {
@ -51,14 +51,14 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC

REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
-REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
-REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
+REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());

if(!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {

std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
-REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

if(!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
@ -32,6 +32,7 @@ namespace ops {
//////////////////////////////////////////////////////////////////////////
// maxpool2d corresponds to poolingMode=0
CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {

auto input = INPUT_VARIABLE(0);

REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf());
@ -40,7 +40,7 @@ namespace nd4j {
//nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf());

nd4j::ops::xw_plus_b op;
-std::unique_ptr<ResultSet> result(op.execute({x, w, b}, {}, {}, {}));
+std::unique_ptr<ResultSet> result(op.evaluate({x, w, b}));
REQUIRE_TRUE(Status::OK() == result->status(), 0, "relu_layer: xw_plus_b op failed on input data.");

auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;
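This execute-to-evaluate swap recurs throughout the commit; the old overload threaded empty tArgs/iArgs/bArgs lists and extra flags through every nested call. A sketch of the convention after the change (variable names hypothetical, not part of the commit):

    nd4j::ops::xw_plus_b op;
    // before: op.execute({x, w, b}, {}, {}, {});
    std::unique_ptr<ResultSet> result(op.evaluate({x, w, b}));
    if (result->status() == Status::OK()) {
        auto out = result->at(0); // outputs remain owned by the ResultSet
    }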
@ -34,7 +34,7 @@ namespace nd4j {

auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
bitcast res;
-auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, false);
+auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false);
if (tZ != &z0) {
    delete tZ;
}
@ -112,7 +112,7 @@ namespace ops {
NDArray originalIndices(*indices); //->ordering(), indices->shapeInfo(), indices->dataType());
originalIndices.linspace(0);
ops::dynamic_partition op;
-auto res = op.execute({&originalIndices, indices}, {}, {numPartition});
+auto res = op.evaluate({&originalIndices, indices}, {numPartition});
REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
ops::dynamic_stitch stichOp;
std::vector<NDArray*> partitions(numPartition * 2);
@ -121,7 +121,7 @@ namespace ops {
partitions[i + numPartition] = gradOutList[i];
}

-auto result = stichOp.execute(partitions, {}, {numPartition}, {}, false);
+auto result = stichOp.evaluate(partitions, {numPartition});
REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
result->at(0)->reshapei(outputList[0]->getShapeAsVector());
outputList[1]->assign(indices);
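The routing trick in this backprop is worth spelling out: partitioning a 0..n-1 index ramp records where dynamic_partition sent each element, and dynamic_stitch of those index pieces against the gradient pieces scatters the incoming gradient back into the original order. A toy trace (values hypothetical):

    // indices = {0, 1, 0}
    // partition of linspace {0, 1, 2} by indices -> {0, 2} and {1}
    // dynamic_stitch({{0, 2}, {1}}, {gradP0, gradP1})
    //   -> {gradP0[0], gradP1[0], gradP0[1]}   (original element order)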
@ -66,7 +66,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {

nd4j::ops::gather op;

-std::unique_ptr<ResultSet> result(op.execute({input, indeces}, {}, {0}, {}));
+std::unique_ptr<ResultSet> result(op.evaluate({input, indeces}, {0}));
REQUIRE_TRUE(result->status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op.");
REQUIRE_TRUE(result->at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
output->assign(result->at(0));
@ -94,7 +94,7 @@ DECLARE_SHAPE_FN(embedding_lookup) {
for (int e = 1; e < outRank; e++)
    shapeInfo[e] = shape::sizeAt(inShapeInfo, e);

-auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), shape::order(inShapeInfo), shapeInfo);
+auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), shapeInfo);
return SHAPELIST(outShapeInfo);
}
@ -32,6 +32,11 @@ namespace ops {
auto finish = INPUT_VARIABLE(1);
auto numOfElements = INPUT_VARIABLE(2);

+if (numOfElements->e<Nd4jLong>(0) == 1) {
+    output->assign(start);
+    return Status::OK();
+}

output->linspace(start->e<double>(0), (finish->e<double>(0) - start->e<double>(0)) / (numOfElements->e<Nd4jLong>(0) - 1.));
return Status::OK();
}
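The new guard exists because the step formula divides by numOfElements - 1, which is zero when a single element is requested. A worked check of the formula with illustrative values:

    double start = 0.0, finish = 10.0;
    Nd4jLong n = 5;
    double step = (finish - start) / (n - 1.); // 2.5
    // linspace produces {0.0, 2.5, 5.0, 7.5, 10.0};
    // with n == 1 the early return assigns {start} instead of dividing by zero.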
@ -74,6 +74,8 @@ namespace nd4j {
DECLARE_SHAPE_FN(onehot) {
auto inShape = inputShape->at(0);

+nd4j::DataType dtype = block.numD() > 0 ? D_ARG(0) : nd4j::DataType::FLOAT32;

int depth = -1;
Nd4jLong axis = -1;
@ -99,7 +101,7 @@ namespace nd4j {
shape.push_back(shape::shapeOf(inShape)[e]);

shape.insert(shape.begin() + axis, depth);
-newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', rank + 1, shape.data());
+newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', rank + 1, shape.data());

return SHAPELIST(newShape);
}
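Net effect of the two onehot hunks, as a hypothetical walk-through of the shape function (shapes illustrative):

    // input indices shape {3, 4}, depth = 5, axis = -1 (appended last)
    //   -> output shape {3, 4, 5}
    // dtype: D_ARG(0) when the caller passes a D-argument, FLOAT32 otherwise
    // (previously block.dataType(), i.e. the block's inferred type)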
@ -25,7 +25,7 @@

namespace nd4j {
namespace ops {
-OP_IMPL(ones_as, 1, 1, false) {
+CUSTOM_OP_IMPL(ones_as, 1, 1, false, 0, 0) {
auto output = OUTPUT_VARIABLE(0);

output->assign(1);
@ -33,11 +33,21 @@ namespace nd4j {
return Status::OK();
}

+DECLARE_SHAPE_FN(ones_as) {
+    auto in = inputShape->at(0);
+    auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in);
+    auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in);
+
+    nd4j_printf("numD: %i; dtype: %s\n", block.numD(), DataTypeUtils::asString(dtype).c_str());
+
+    return SHAPELIST(shape);
+}

DECLARE_TYPES(ones_as) {
getOpDescriptor()
    ->setAllowedInputTypes(nd4j::DataType::ANY)
    ->setAllowedOutputTypes(nd4j::DataType::ANY)
-    ->setSameMode(true);
+    ->setSameMode(false);
}
}
}
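A hedged usage sketch of ones_as after the conversion to a custom op (array shape illustrative; evaluate() as used elsewhere in this diff):

    auto x = NDArrayFactory::create<float>('c', {2, 3});
    nd4j::ops::ones_as op;
    auto result = op.evaluate({&x}); // ones with x's shape; dtype follows x
                                     // unless a D-argument overrides it
    delete result;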
@ -130,7 +130,7 @@ DECLARE_SHAPE_FN(range) {
const int numIArgs = block.getIArguments()->size();

Nd4jLong steps = 0;
-nd4j::DataType dataType = nd4j::DataType::INHERIT;
+nd4j::DataType dataType = block.numD() ? D_ARG(0) : nd4j::DataType::INHERIT;

if (numInArrs > 0) {
auto isR = INPUT_VARIABLE(0)->isR();
@ -159,6 +159,8 @@ DECLARE_SHAPE_FN(range) {
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");

steps = static_cast<Nd4jLong >((limit - start) / delta);

+if (!block.numD())
    dataType = INPUT_VARIABLE(0)->dataType();

if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
@ -187,6 +189,8 @@ DECLARE_SHAPE_FN(range) {
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");

steps = static_cast<Nd4jLong >((limit - start) / delta);

+if (!block.numD())
    dataType = INPUT_VARIABLE(0)->dataType();

if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
@ -214,10 +218,12 @@ DECLARE_SHAPE_FN(range) {

REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");

+if (!block.numD()) {
    if (limit > DataTypeUtils::max<int>())
        dataType = nd4j::DataType::INT64;
    else
        dataType = nd4j::DataType::INT32;
+}

steps = (limit - start) / delta;
@ -248,10 +254,13 @@ DECLARE_SHAPE_FN(range) {
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");

steps = static_cast<Nd4jLong >((limit - start) / delta);

+if (!block.numD()) {
    if (Environment::getInstance()->precisionBoostAllowed())
        dataType = nd4j::DataType::DOUBLE;
    else
        dataType = Environment::getInstance()->defaultFloatDataType();
+}

if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
    ++steps;
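The element-count logic these range hunks share is easy to misread; a worked instance with illustrative values:

    double start = 0, limit = 7, delta = 2;
    Nd4jLong steps = static_cast<Nd4jLong>((limit - start) / delta); // 3
    if (math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double>(limit))
        ++steps; // |0 + 3*2| = 6 < 7, so steps becomes 4
    // range(0, 7, 2) -> {0, 2, 4, 6}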
@ -0,0 +1,75 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit, K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// Created by GS <sgazeos@gmail.com> at 01/22/2020
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_solve)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/solve.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0) {
    auto a = INPUT_VARIABLE(0);
    auto b = INPUT_VARIABLE(1);
    auto z = OUTPUT_VARIABLE(0);

    bool useAdjoint = false;

    if (block.numB() > 0) {
        useAdjoint = B_ARG(0);
    }

    REQUIRE_TRUE(a->rankOf() >=2, 0, "solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
    REQUIRE_TRUE(b->rankOf() >=2, 0, "solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());

    REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "solve: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
    REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "solve: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));
    auto input = a;
    if (useAdjoint) {
        auto adjointA = a->ulike();
        helpers::adjointMatrix(block.launchContext(), a, &adjointA);
        input = new NDArray(adjointA); //.detach();
    };

    auto res = helpers::solveFunctor(block.launchContext(), input, b, useAdjoint, z);
    if (input != a)
        delete input;

    return Status::OK();
}

DECLARE_SHAPE_FN(solve) {
    auto in0 = inputShape->at(1);
    auto in1 = inputShape->at(1);
    auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());

    return SHAPELIST(CONSTANT(luShape));
}

DECLARE_TYPES(solve) {
    getOpDescriptor()
        ->setAllowedInputTypes({ALL_FLOATS})
        ->setAllowedOutputTypes({ALL_FLOATS})
        ->setSameMode(false);
}
}
}

#endif
@ -25,7 +25,7 @@

namespace nd4j {
namespace ops {
-OP_IMPL(zeros_as, 1, 1, false) {
+CUSTOM_OP_IMPL(zeros_as, 1, 1, false, 0, 0) {
auto out = OUTPUT_VARIABLE(0);

out->assign(0); // output is filled by zero by default
@ -35,11 +35,20 @@ namespace nd4j {
DECLARE_SYN(zeroslike, zeros_as);
DECLARE_SYN(zeros_like, zeros_as);

+DECLARE_SHAPE_FN(zeros_as) {
+    auto in = inputShape->at(0);
+    auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in);
+    auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in);
+
+    return SHAPELIST(shape);
+}

DECLARE_TYPES(zeros_as) {
getOpDescriptor()
    ->setAllowedInputTypes(nd4j::DataType::ANY)
    ->setAllowedOutputTypes(nd4j::DataType::ANY)
-    ->setSameMode(true);
+    ->setSameMode(false);
}
}
}
@ -84,7 +84,7 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {

// forward steps
nd4j::ops::dynamic_rnn dynamicRnn;
-auto resultsFW = dynamicRnn.execute({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {}, {timeMajor}, {}, false, x->dataType());
+auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor});
hFW->assign(resultsFW->at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW]
hFWFinal->assign(resultsFW->at(1));
@ -97,17 +97,17 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {

// reverse x
nd4j::ops::reverse_sequence reverse;
-auto resultsIn = timeMajor ? reverse.execute({x, seqLen}, {}, {0, 1}, {}, false, x->dataType()) : reverse.execute({x, seqLen}, {}, {1, 0}, {}, false, x->dataType());
+auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0});
REQUIRE_TRUE (resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
auto revInput = resultsIn->at(0);

// backward steps
-auto resultsBW = dynamicRnn.execute({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {}, {timeMajor}, {});
+auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor});
auto hBWtemp = resultsBW->at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW]
hBWFinal->assign(resultsBW->at(1));

// reverse hBWtemp
-auto resultsOut = timeMajor ? reverse.execute({hBWtemp, seqLen}, {}, {0, 1}, {}) : reverse.execute({hBWtemp, seqLen}, {}, {1, 0}, {});
+auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0});
hBW->assign(resultsOut->at(0));

delete resultsOut;
@ -48,7 +48,7 @@ namespace ops {

auto conv = ArrayUtils::toLongVector(*block.getIArguments());

-auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), shape::order(in), conv);
+auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), shape::order(in), conv);

return SHAPELIST(newShape);
}
@ -487,7 +487,7 @@ namespace nd4j {
 *
 */
#if NOT_EXCLUDED(OP_zeros_as)
-DECLARE_OP(zeros_as, 1, 1, false);
+DECLARE_CUSTOM_OP(zeros_as, 1, 1, false, 0, 0);
#endif

/**
@ -497,7 +497,7 @@ namespace nd4j {
 *
 */
#if NOT_EXCLUDED(OP_ones_as)
-DECLARE_OP(ones_as, 1, 1, false);
+DECLARE_CUSTOM_OP(ones_as, 1, 1, false, 0, 0);
#endif

/**
@ -1076,6 +1076,24 @@ namespace nd4j {
DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0);
#endif

+/**
+ * solve op. - solve systems of linear equations - general method.
+ *
+ * input params:
+ *    0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations
+ *    1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
+ *
+ * boolean args:
+ *    0 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used
+ *
+ * return value:
+ *    tensor with dimension (x * y * z * ::: * M * K) with solutions
+ *
+ */
+#if NOT_EXCLUDED(OP_solve)
+DECLARE_CUSTOM_OP(solve, 2, 1, true, 0, 0);
+#endif

/**
 * lu op. - make LUP decomposition of given batch of 2D square matricies
 *
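A hedged usage sketch for the declaration above (array values illustrative, not from the commit; evaluate() and NDArrayFactory::create as used elsewhere in this diff):

    // solve A*x = B for one 2x2 system
    auto a = NDArrayFactory::create<float>('c', {2, 2}, {3.f, 1.f, 1.f, 2.f});
    auto b = NDArrayFactory::create<float>('c', {2, 1}, {9.f, 8.f});

    nd4j::ops::solve op;
    auto result = op.evaluate({&a, &b}); // boolean arg 0 (adjoint) defaults to false
    // result->at(0) is {2, 1}: x = {2, 3}, since 3*2 + 1*3 = 9 and 1*2 + 2*3 = 8
    delete result;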
@ -237,11 +237,46 @@ namespace helpers {
samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
}

+template <typename T>
+static void doolitleLU(LaunchContext* context, NDArray* compound, Nd4jLong rowNum) {
+    auto input = compound->dup();
+    compound->nullify();
+
+    // Decomposing matrix into Upper and Lower
+    // triangular matrix
+    for (auto i = 0; i < rowNum; i++) {
+
+        // Upper Triangular
+        for (auto k = i; k < rowNum; k++) {
+
+            // Summation of L(i, j) * U(j, k)
+            int sum = 0;
+            for (int j = 0; j < i; j++)
+                sum += compound->t<T>(i,j) * compound->t<T>(j,k);
+
+            // Evaluating U(i, k)
+            compound->t<T>(i, k) = input.t<T>(i, k) - sum;
+        }
+
+        // Lower Triangular
+        for (int k = i + 1; k < rowNum; k++) {
+            // Summation of L(k, j) * U(j, i)
+            int sum = 0;
+            for (int j = 0; j < i; j++)
+                sum += compound->t<T>(k,j) * compound->t<T>(j, i);
+
+            // Evaluating L(k, i)
+            compound->t<T>(k, i) = (input.t<T>(k, i) - sum) / compound->t<T>(i,i);
+        }
+    }
+}

template <typename T, typename I>
static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) {

//const int rowNum = compound->rows();
// const int columnNum = output->columns();
+if (permutation) { // LUP algorithm
permutation->linspace(0);
auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT<I>();
auto compoundBuf = compound->bufferAsT<T>();

@ -252,12 +287,17 @@ namespace helpers {
if (pivotIndex < 0) {
    throw std::runtime_error("helpers::luNN_: input matrix is singular.");
}
-math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)], permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
+math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)],
+                permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
swapRows(compoundBuf, compoundShape, i, pivotIndex);

processColumns(i, rowNum, compoundBuf, compoundShape);
}
}
+else { // Doolitle algorithm with LU decomposition
+    doolitleLU<T>(context, compound, rowNum);
+}
+}

template <typename T, typename I>
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {

@ -265,17 +305,20 @@ namespace helpers {

output->assign(input); // fill up output tensor with zeros
ResultSet outputs = output->allTensorsAlongDimension({-2, -1});
-ResultSet permutations = permutationVectors->allTensorsAlongDimension({-1});
+ResultSet permutations;
+if (permutationVectors)
+    permutations = permutationVectors->allTensorsAlongDimension({-1});

auto loop = PRAGMA_THREADS_FOR {
    for (auto i = start; i < stop; i += increment) {
-        luNN_<T, I>(context, outputs.at(i), permutations.at(i), n);
+        luNN_<T, I>(context, outputs.at(i), permutationVectors?permutations.at(i):nullptr, n);
    }
};
samediff::Threads::parallel_for(loop, 0, outputs.size(), 1);
}

void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {
-BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES);
+BUILD_DOUBLE_SELECTOR(input->dataType(), permutation?permutation->dataType():DataType::INT32, lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES);
}

// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES);
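To make the new doolitleLU concrete, a worked 2x2 factorization that follows its two inner loops (values illustrative):

    // A = [[4, 3],
    //      [6, 3]]
    // i = 0: U(0,0) = 4, U(0,1) = 3; L(1,0) = A(1,0) / U(0,0) = 6/4 = 1.5
    // i = 1: U(1,1) = A(1,1) - L(1,0)*U(0,1) = 3 - 1.5*3 = -1.5
    // packed compound result (L below the diagonal, unit diagonal implicit):
    //   [[4.0,  3.0],
    //    [1.5, -1.5]]   and L * U == A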
@ -0,0 +1,100 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit, K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author GS <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#include <NDArray.h>
#include <NDArrayFactory.h>
#include <execution/Threads.h>
#include <helpers/MmulHelper.h>

#include "../triangular_solve.h"
#include "../lup.h"
#include "../solve.h"

namespace nd4j {
namespace ops {
namespace helpers {

// --------------------------------------------------------------------------------------------------------------------------------------- //
template <typename T>
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
    auto inputPart = input->allTensorsAlongDimension({-2, -1});
    auto outputPart = output->allTensorsAlongDimension({-2, -1});
    output->assign(input);
    auto batchLoop = PRAGMA_THREADS_FOR {
        for (auto batch = start; batch < stop; batch += increment) {
            for (auto r = 0; r < input->rows(); r++) {
                for (auto c = 0; c < r; c++) {
                    math::nd4j_swap(outputPart[batch]->t<T>(r, c) , outputPart[batch]->t<T>(c, r));
                }
            }
        }
    };
    samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
}

// --------------------------------------------------------------------------------------------------------------------------------------- //
template <typename T>
static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {

    // stage 1: LU decomposition batched
    auto leftOutput = leftInput->ulike();
    auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back();
    auto permutations = NDArrayFactory::create<int>('c', permuShape, context);
    helpers::lu(context, leftInput, &leftOutput, &permutations);
    auto P = leftInput->ulike(); //permutations batched matrix
    P.nullify(); // to fill up matricies with zeros
    auto PPart = P.allTensorsAlongDimension({-2,-1});
    auto permutationsPart = permutations.allTensorsAlongDimension({-1});

    for (auto batch = 0; batch < permutationsPart.size(); ++batch) {
        for (auto row = 0; row < PPart[batch]->rows(); ++row) {
            PPart[batch]->t<T>(row, permutationsPart[batch]->t<int>(row)) = T(1.f);
        }
    }

    auto leftLower = leftOutput.dup();
    auto rightOutput = rightInput->ulike();
    auto rightPermuted = rightOutput.ulike();
    MmulHelper::matmul(&P, rightInput, &rightPermuted, 0, 0);
    ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1});
    for (auto i = 0; i < leftLowerPart.size(); i++) {
        for (auto r = 0; r < leftLowerPart[i]->rows(); r++)
            leftLowerPart[i]->t<T>(r,r) = (T)1.f;
    }
    // stage 2: triangularSolveFunctor for Lower with given b
    helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, false, &rightOutput);
    // stage 3: triangularSolveFunctor for Upper with output of previous stage
    helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output);

    return Status::OK();
}

// --------------------------------------------------------------------------------------------------------------------------------------- //
int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {
    BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES);
}
// --------------------------------------------------------------------------------------------------------------------------------------- //
void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
    BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES);
}
// --------------------------------------------------------------------------------------------------------------------------------------- //
}
}
}
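A compact summary of the algorithm solveFunctor_ implements (standard LUP solving; the stage numbers mirror the comments in the code):

    // Given A*x = B and the factorization P*A = L*U
    // (L unit lower triangular, U upper triangular, P a permutation):
    //   stage 1: factor A              -> P, L, U   (helpers::lu)
    //   stage 2: forward substitution  L*y = P*B    (lower triangular solve)
    //   stage 3: back substitution     U*x = y      (upper triangular solve)
    // The x from stage 3 satisfies A*x = B.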
@ -41,13 +41,16 @@ namespace helpers {
template <typename T>
static void lowerTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
auto rows = leftInput->rows();
+auto cols = rightInput->columns();
//output->t<T>(0,0) = rightInput->t<T>(0,0) / leftInput->t<T>(0,0);
for (auto r = 0; r < rows; r++) {
-    auto sum = rightInput->t<T>(r, 0);
+    for (auto j = 0; j < cols; j++) {
+        auto sum = rightInput->t<T>(r, j);
        for (auto c = 0; c < r; c++) {
-            sum -= leftInput->t<T>(r,c) * output->t<T>(c, 0);
+            sum -= leftInput->t<T>(r, c) * output->t<T>(c, j);
        }
+        output->t<T>(r, j) = sum / leftInput->t<T>(r, r);
+    }
-    output->t<T>(r, 0) = sum / leftInput->t<T>(r, r);
}
}
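The hunk generalizes forward substitution from the first right-hand-side column to all of them; per column j the recurrence is the standard

    x_{r,j} = \frac{b_{r,j} - \sum_{c=0}^{r-1} L_{r,c}\,x_{c,j}}{L_{r,r}}, \qquad r = 0,\dots,\text{rows}-1

which the upper-triangular variant below mirrors with the row loop running from the last row upward.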
@ -68,13 +71,15 @@ namespace helpers {
template <typename T>
static void upperTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
auto rows = leftInput->rows();
+auto cols = rightInput->columns();
for (auto r = rows; r > 0; r--) {
-    auto sum = rightInput->t<T>(r - 1, 0);
+    for (auto j = 0; j < cols; j++) {
+        auto sum = rightInput->t<T>(r - 1, j);
        for (auto c = r; c < rows; c++) {
-            sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, 0);
+            sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, j);
        }
+        output->t<T>(r - 1, j) = sum / leftInput->t<T>(r - 1, r - 1);
+    }
-    output->t<T>(r - 1, 0) = sum / leftInput->t<T>(r - 1, r - 1);
}
}