commit
630bb3c9b6
|
@ -50,6 +50,7 @@ import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
@ -816,6 +817,37 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
assertEquals(vec1.getWordVectorMatrix("money"), vec2.getWordVectorMatrix("money"));
|
assertEquals(vec1.getWordVectorMatrix("money"), vec2.getWordVectorMatrix("money"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWordsNearestSum() throws IOException {
|
||||||
|
log.info("Load & Vectorize Sentences....");
|
||||||
|
SentenceIterator iter = new BasicLineIterator(inputFile);
|
||||||
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
|
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||||
|
|
||||||
|
log.info("Building model....");
|
||||||
|
Word2Vec vec = new Word2Vec.Builder()
|
||||||
|
.minWordFrequency(5)
|
||||||
|
.iterations(1)
|
||||||
|
.layerSize(100)
|
||||||
|
.seed(42)
|
||||||
|
.windowSize(5)
|
||||||
|
.iterate(iter)
|
||||||
|
.tokenizerFactory(t)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
log.info("Fitting Word2Vec model....");
|
||||||
|
vec.fit();
|
||||||
|
log.info("Writing word vectors to text file....");
|
||||||
|
log.info("Closest Words:");
|
||||||
|
Collection<String> lst = vec.wordsNearestSum("day", 10);
|
||||||
|
log.info("10 Words closest to 'day': {}", lst);
|
||||||
|
assertTrue(lst.contains("week"));
|
||||||
|
assertTrue(lst.contains("night"));
|
||||||
|
assertTrue(lst.contains("year"));
|
||||||
|
assertTrue(lst.contains("years"));
|
||||||
|
assertTrue(lst.contains("time"));
|
||||||
|
}
|
||||||
|
|
||||||
private static void printWords(String target, Collection<String> list, Word2Vec vec) {
|
private static void printWords(String target, Collection<String> list, Word2Vec vec) {
|
||||||
System.out.println("Words close to [" + target + "]:");
|
System.out.println("Words close to [" + target + "]:");
|
||||||
for (String word : list) {
|
for (String word : list) {
|
||||||
|
|
|
@ -351,7 +351,8 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
|
||||||
if (lookupTable instanceof InMemoryLookupTable) {
|
if (lookupTable instanceof InMemoryLookupTable) {
|
||||||
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
|
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
|
||||||
INDArray syn0 = l.getSyn0();
|
INDArray syn0 = l.getSyn0();
|
||||||
INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
|
INDArray temp = syn0.norm2(0).rdivi(1).reshape(words.shape());
|
||||||
|
INDArray weights = temp.muli(words);
|
||||||
INDArray distances = syn0.mulRowVector(weights).sum(1);
|
INDArray distances = syn0.mulRowVector(weights).sum(1);
|
||||||
INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
|
INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
|
||||||
INDArray sort = sorted[0];
|
INDArray sort = sorted[0];
|
||||||
|
|
|
@ -47,7 +47,7 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 1, 0) {
|
||||||
NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext());
|
NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext());
|
||||||
factorT.p(0, factor);
|
factorT.p(0, factor);
|
||||||
// this is contrast calculation
|
// this is contrast calculation
|
||||||
*output = (*input - mean) * factorT + mean;
|
output->assign((*input - mean) * factorT + mean);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.NoOp;
|
import org.nd4j.linalg.api.ops.NoOp;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.*;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
|
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
|
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
|
||||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
|
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
|
||||||
|
@ -2649,6 +2650,33 @@ public class DifferentialFunctionFactory {
|
||||||
return new NextIteration(sameDiff, x).outputVariable();
|
return new NextIteration(sameDiff, x).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable adjustContrast(SDVariable in, SDVariable factor) {
|
||||||
|
return new AdjustContrast(sameDiff, in, factor).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable adjustContrastV2(SDVariable in, SDVariable factor) {
|
||||||
|
return new AdjustContrastV2(sameDiff, in, factor).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable bitCast(SDVariable in, SDVariable dataType) {
|
||||||
|
return new BitCast(sameDiff, in, dataType).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable compareAndBitpack(SDVariable threshold) {
|
||||||
|
return new CompareAndBitpack(sameDiff, threshold).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable divideNoNan(SDVariable in1, SDVariable in2) {
|
||||||
|
return new DivideNoNan(sameDiff, in1, in2).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable drawBoundingBoxes(SDVariable boxes, SDVariable colors) {
|
||||||
|
return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max) {
|
||||||
|
return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return "DifferentialFunctionFactory{methodNames=" + methodNames + "}";
|
return "DifferentialFunctionFactory{methodNames=" + methodNames + "}";
|
||||||
|
|
|
@ -581,8 +581,14 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge.class,
|
org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge.class,
|
||||||
org.nd4j.linalg.api.ops.random.impl.Range.class,
|
org.nd4j.linalg.api.ops.random.impl.Range.class,
|
||||||
org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
|
org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
|
||||||
org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class
|
org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.AdjustContrast.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.BitCast.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class
|
||||||
);
|
);
|
||||||
|
|
||||||
static {
|
static {
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
public class AdjustContrast extends BaseAdjustContrast {
|
||||||
|
|
||||||
|
public AdjustContrast() {super();}
|
||||||
|
|
||||||
|
public AdjustContrast(INDArray in, double factor, INDArray out) {
|
||||||
|
super(in, factor, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) {
|
||||||
|
super(sameDiff,new SDVariable[]{in,factor});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "adjust_contrast";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "AdjustContrast";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
public class AdjustContrastV2 extends BaseAdjustContrast {
|
||||||
|
|
||||||
|
public AdjustContrastV2() {super();}
|
||||||
|
|
||||||
|
public AdjustContrastV2(INDArray in, double factor, INDArray out) {
|
||||||
|
super(in, factor, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) {
|
||||||
|
super( sameDiff,new SDVariable[]{in,factor});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "adjust_contrast_v2";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "AdjustContrast";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
public abstract class BaseAdjustContrast extends DynamicCustomOp {
|
||||||
|
public BaseAdjustContrast() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public BaseAdjustContrast(INDArray in, double factor, INDArray out) {
|
||||||
|
Preconditions.checkArgument(in.rank() >= 3,
|
||||||
|
String.format("AdjustContrast: op expects rank of input array to be >= 3, but got %d instead", in.rank()));
|
||||||
|
inputArguments.add(in);
|
||||||
|
outputArguments.add(out);
|
||||||
|
|
||||||
|
addTArgument(factor);
|
||||||
|
}
|
||||||
|
|
||||||
|
public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) {
|
||||||
|
super("", sameDiff, vars);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
public class BitCast extends DynamicCustomOp {
|
||||||
|
public BitCast() {}
|
||||||
|
|
||||||
|
public BitCast(INDArray in, int dataType, INDArray out) {
|
||||||
|
inputArguments.add(in);
|
||||||
|
outputArguments.add(out);
|
||||||
|
iArguments.add(Long.valueOf(dataType));
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) {
|
||||||
|
super("", sameDiff, new SDVariable[]{in, dataType});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "bitcast";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "Bitcast";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,31 @@
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
public class CompareAndBitpack extends DynamicCustomOp {
|
||||||
|
public CompareAndBitpack() {}
|
||||||
|
|
||||||
|
public CompareAndBitpack(INDArray in, double threshold, INDArray out) {
|
||||||
|
inputArguments.add(in);
|
||||||
|
inputArguments.add(Nd4j.scalar(threshold));
|
||||||
|
outputArguments.add(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
public CompareAndBitpack(SameDiff sameDiff, SDVariable threshold) {
|
||||||
|
super("", sameDiff, new SDVariable[]{threshold});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "compare_and_bitpack";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "CompareAndBitpack";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.analysis.function.Divide;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
public class DivideNoNan extends DynamicCustomOp {
|
||||||
|
public DivideNoNan() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public DivideNoNan(INDArray in1, INDArray in2, INDArray out) {
|
||||||
|
inputArguments.add(in1);
|
||||||
|
inputArguments.add(in2);
|
||||||
|
outputArguments.add(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
public DivideNoNan(SameDiff sameDiff, SDVariable in1, SDVariable in2) {
|
||||||
|
super("", sameDiff, new SDVariable[]{in1, in2});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "divide_no_nan";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "DivNoNan";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
public class DrawBoundingBoxes extends DynamicCustomOp {
|
||||||
|
public DrawBoundingBoxes() {}
|
||||||
|
|
||||||
|
public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors,
|
||||||
|
INDArray output) {
|
||||||
|
inputArguments.add(images);
|
||||||
|
inputArguments.add(boxes);
|
||||||
|
inputArguments.add(colors);
|
||||||
|
outputArguments.add(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
public DrawBoundingBoxes(SameDiff sameDiff, SDVariable boxes, SDVariable colors) {
|
||||||
|
super("", sameDiff, new SDVariable[]{boxes, colors});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "draw_bounding_boxes";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "DrawBoundingBoxes";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
|
||||||
|
public FakeQuantWithMinMaxVarsPerChannel() {}
|
||||||
|
|
||||||
|
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max,
|
||||||
|
INDArray output) {
|
||||||
|
Preconditions.checkArgument(min.isVector() && max.isVector() &&
|
||||||
|
min.length() == max.length(),
|
||||||
|
"FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length");
|
||||||
|
inputArguments.add(x);
|
||||||
|
inputArguments.add(min);
|
||||||
|
inputArguments.add(max);
|
||||||
|
outputArguments.add(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) {
|
||||||
|
super("", sameDiff, new SDVariable[]{x, min, max});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "fake_quant_with_min_max_vars_per_channel";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "FakeQuantWithMinMaxVarsPerChannel";
|
||||||
|
}
|
||||||
|
}
|
|
@ -26,8 +26,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.custom.Flatten;
|
import org.nd4j.linalg.api.ops.custom.*;
|
||||||
import org.nd4j.linalg.api.ops.custom.ScatterUpdate;
|
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||||
|
@ -807,4 +806,118 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
|
|
||||||
Nd4j.getExecutioner().commit();
|
Nd4j.getExecutioner().commit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testAdjustContrast() {
|
||||||
|
INDArray in = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4*4*3).reshape(4,4,3);
|
||||||
|
INDArray out = Nd4j.zeros(4,4,3);
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new double[]{-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
|
||||||
|
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
|
||||||
|
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
|
||||||
|
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
|
||||||
|
}).reshape(4,4,3);
|
||||||
|
Nd4j.exec(new AdjustContrast(in, 2.0, out));
|
||||||
|
|
||||||
|
assertArrayEquals(out.shape(), in.shape());
|
||||||
|
assertEquals(expected, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testAdjustContrastV2() {
|
||||||
|
INDArray in = Nd4j.linspace(DataType.DOUBLE,1.0,1.0, 4*4*3).reshape(4,4,3);
|
||||||
|
INDArray out = Nd4j.createUninitialized(4,4,3);
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new double[]{-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
|
||||||
|
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
|
||||||
|
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
|
||||||
|
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
|
||||||
|
}).reshape(4,4,3);
|
||||||
|
|
||||||
|
Nd4j.exec(new AdjustContrastV2(in, 2.0, out));
|
||||||
|
|
||||||
|
assertArrayEquals(out.shape(), in.shape());
|
||||||
|
assertEquals(expected, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBitCast() {
|
||||||
|
INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2);
|
||||||
|
INDArray out = Nd4j.createUninitialized(2,2);
|
||||||
|
|
||||||
|
Nd4j.exec(new BitCast(in, DataType.DOUBLE.toInt(), out));
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new double[]{2., 512., 8192., 131072.032 }).reshape(2,2);
|
||||||
|
assertArrayEquals(new long[]{2,2}, out.shape());
|
||||||
|
assertEquals(expected, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCompareAndBitpack() {
|
||||||
|
INDArray in = Nd4j.createFromArray(new double[]{-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f,
|
||||||
|
-2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}).reshape( 2,3,4);
|
||||||
|
INDArray out = Nd4j.createUninitialized(DataType.UBYTE, 2,3,4);
|
||||||
|
INDArray expected = Nd4j.createFromArray(new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}).
|
||||||
|
reshape(2,3,4);
|
||||||
|
|
||||||
|
Nd4j.exec(new CompareAndBitpack(in ,2.0, out));
|
||||||
|
assertArrayEquals(new long[]{2,3,4}, out.shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDivideNoNan() {
|
||||||
|
INDArray in1 = Nd4j.rand(DataType.DOUBLE, 2,3,4);
|
||||||
|
INDArray in2 = Nd4j.rand(DataType.DOUBLE, 2,3,4);
|
||||||
|
INDArray out = Nd4j.createUninitialized(DataType.DOUBLE, 2,3,4);
|
||||||
|
|
||||||
|
Nd4j.exec(new DivideNoNan(in1, in2, out));
|
||||||
|
assertArrayEquals(new long[]{2,3,4}, out.shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDrawBoundingBoxes() {
|
||||||
|
INDArray images = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 2*4*5*3).reshape(2,4,5,3);
|
||||||
|
INDArray boxes = Nd4j.createFromArray(new float[]{ 0.0f , 0.0f , 1.0f , 1.0f,
|
||||||
|
0.1f, 0.2f, 0.9f, 0.8f,
|
||||||
|
0.3f, 0.3f, 0.7f, 0.7f,
|
||||||
|
0.4f, 0.4f, 0.6f, 0.6f}).reshape(2,2,4);
|
||||||
|
INDArray colors = Nd4j.createFromArray(new float[]{
|
||||||
|
201.0f, 202.0f, 203.0f, 127.0f, 128.0f, 129.0f}).
|
||||||
|
reshape(2,3);
|
||||||
|
INDArray output = Nd4j.create(DataType.FLOAT, images.shape());
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f,
|
||||||
|
127.f, 128.f, 129.f, 201.f, 202.f, 203.f,
|
||||||
|
127.f, 128.f, 129.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f,
|
||||||
|
127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f,
|
||||||
|
201.f, 202.f, 203.f, 201.f ,202.f ,203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f,
|
||||||
|
|
||||||
|
61.f, 62.f, 63.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 70.f, 71.f, 72.f, 73.f, 74.f, 75.f,
|
||||||
|
76.f, 77.f, 78.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f,
|
||||||
|
91.f, 92.f, 93.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 100.f, 101.f, 102.f, 103.f, 104.f, 105.f,
|
||||||
|
106.f, 107.f, 108.f, 109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f}).
|
||||||
|
reshape(2,4,5,3);
|
||||||
|
|
||||||
|
Nd4j.exec(new DrawBoundingBoxes(images, boxes, colors, output));
|
||||||
|
|
||||||
|
assertArrayEquals(images.shape(), output.shape());
|
||||||
|
assertEquals(expected, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void FakeQuantWithMinMaxVarsPerChannel() {
|
||||||
|
|
||||||
|
INDArray x = Nd4j.createFromArray(new float[]{-63.80f, -63.75f, -63.4f, -63.5f, 0.0f, 0.1f}).
|
||||||
|
reshape(1,2,3,1);
|
||||||
|
|
||||||
|
INDArray min = Nd4j.createFromArray(new float[]{-63.65f});
|
||||||
|
INDArray max = Nd4j.createFromArray(new float[]{0.1f});
|
||||||
|
|
||||||
|
INDArray output = Nd4j.createUninitialized(DataType.FLOAT, 1,2,3,1);
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{-63.75f, -63.75f, -63.5f, -63.5f, 0.f, 0.f}).
|
||||||
|
reshape(1,2,3,1);
|
||||||
|
|
||||||
|
Nd4j.exec(new FakeQuantWithMinMaxVarsPerChannel(x,min,max,output));
|
||||||
|
|
||||||
|
assertEquals(expected, output);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue