parent
3c9a2a5cd9
commit
6d8a063c9b
|
@ -89,12 +89,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
"rnn/bstack/d_.*",
|
||||
|
||||
//2019/05/21 - Failing on AVX2/512 intermittently (Linux, OSX), passing elsewhere
|
||||
"unsorted_segment/.*",
|
||||
//"unsorted_segment/.*",
|
||||
|
||||
//2019/05/21 - Failing on windows-x86_64-cuda-9.2 only -
|
||||
"conv_4",
|
||||
"g_09",
|
||||
"unsorted_segment/unsorted_segment_mean_rank2",
|
||||
//"unsorted_segment/unsorted_segment_mean_rank2",
|
||||
|
||||
//2019/05/28 - JVM crash on ppc64le only - See issue 7657
|
||||
"g_11",
|
||||
|
@ -115,7 +115,10 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
"roll/.*",
|
||||
|
||||
// 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455
|
||||
"matrix_band_part/.*"
|
||||
"matrix_band_part/.*",
|
||||
|
||||
// 12.20.2019 - https://github.com/eclipse/deeplearning4j/issues/8559
|
||||
"fused_batch_norm/.*"
|
||||
};
|
||||
|
||||
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
|
||||
|
|
|
@ -1102,23 +1102,6 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBetaInc() {
|
||||
Nd4j.getRandom().setSeed(10);
|
||||
INDArray a = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3);
|
||||
INDArray b = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3);
|
||||
INDArray x = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3);
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{0.4121f, 0.3926f, 0.4082f,
|
||||
0.4414f, 0.5000f, 0.5703f,
|
||||
0.6562f, 0.7656f, 0.8828f}).reshape(3,3);
|
||||
|
||||
BetaInc op = new BetaInc(a,b,x);
|
||||
INDArray[] out = Nd4j.exec(op);
|
||||
assertArrayEquals(expected.shape(), out[0].shape());
|
||||
for (int i = 0; i < 3; ++i)
|
||||
assertArrayEquals(expected.toDoubleMatrix()[i], out[0].toDoubleMatrix()[i], 1e-4);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFusedBatchNorm() {
|
||||
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4);
|
||||
|
@ -1150,6 +1133,34 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
assertArrayEquals(expectedBatchVar.shape(), batchVar.shape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFusedBatchNorm1() {
|
||||
INDArray x = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,
|
||||
0.7271f, 0.1804f, 0.5056f, 0.8925f,
|
||||
0.5461f, 0.9234f, 0.0856f, 0.7938f,
|
||||
0.6591f, 0.5555f, 0.1596f, 0.3087f,
|
||||
0.1548f, 0.4695f, 0.9939f, 0.6113f,
|
||||
0.6765f, 0.1800f, 0.6750f, 0.2246f}).reshape(1,2,3,4);
|
||||
INDArray scale = Nd4j.createFromArray(new float[]{ 0.7717f, 0.9281f, 0.9846f, 0.4838f});
|
||||
INDArray offset = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f});
|
||||
|
||||
INDArray y = Nd4j.createUninitialized(DataType.DOUBLE, x.shape());
|
||||
INDArray batchMean = Nd4j.create(4);
|
||||
INDArray batchVar = Nd4j.create(4);
|
||||
|
||||
FusedBatchNorm op = new FusedBatchNorm(x,scale,offset,0,1,
|
||||
y, batchMean, batchVar);
|
||||
|
||||
INDArray expectedY = Nd4j.createFromArray(new float[]{1.637202024f, 1.521406889f, 1.48303616f, -0.147269756f,
|
||||
1.44721508f, -0.51030159f, 0.810390055f, 1.03076458f,
|
||||
0.781284988f, 1.921229601f, -0.481337309f, 0.854952335f,
|
||||
1.196854949f, 0.717398405f, -0.253610134f, -0.00865117f,
|
||||
-0.658405781f,0.43602103f, 2.311818838f, 0.529999137f,
|
||||
1.260738254f, -0.511638165f, 1.331095099f, -0.158477545f}).reshape(x.shape());
|
||||
Nd4j.exec(op);
|
||||
assertArrayEquals(expectedY.shape(), y.shape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMatrixBandPart() {
|
||||
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3);
|
||||
|
@ -1354,6 +1365,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
|
||||
// Exact copy of libnd4j test
|
||||
@Test
|
||||
@Ignore
|
||||
public void testRgbToHsv() {
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f,
|
||||
3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f,
|
||||
|
@ -1381,18 +1393,18 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
// Exact copy of libnd4j test
|
||||
@Ignore
|
||||
@Test
|
||||
public void testHsvToRgb() {
|
||||
INDArray input = Nd4j.createFromArray(new float[]{263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f,
|
||||
0.9047619f, 0.65882353f, 71.30044843f, 1.f,
|
||||
0.8745098f, 180.f, 0.74871795f, 0.76470588f,
|
||||
INDArray input = Nd4j.createFromArray(new float[]{130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f,
|
||||
153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f}).reshape(8,3);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f,
|
||||
0.9047619f, 0.65882353f, 71.30044843f, 1.f, 0.8745098f, 180.f, 0.74871795f, 0.76470588f,
|
||||
77.6f, 0.49019608f, 0.6f, 260.74468085f,
|
||||
0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f,
|
||||
0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f}).reshape(8,3);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f,
|
||||
153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f}).reshape(8,3);
|
||||
|
||||
HsvToRgb op = new HsvToRgb(input);
|
||||
INDArray[] ret = Nd4j.exec(op);
|
||||
assertEquals(ret[0], expected);
|
||||
|
|
Loading…
Reference in New Issue