Update allowed data types in pooling
parent
53bfdb9994
commit
968eaad2dd
|
@ -218,7 +218,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) {
|
void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) {
|
||||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), NUMERIC_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -301,7 +301,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), NUMERIC_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -315,11 +315,11 @@ void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input
|
||||||
switch (poolingMode) {
|
switch (poolingMode) {
|
||||||
|
|
||||||
case MAX_POOL: {
|
case MAX_POOL: {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), NUMERIC_TYPES);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case AVG_POOL: {
|
case AVG_POOL: {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), NUMERIC_TYPES);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case PNORM_POOL: {
|
case PNORM_POOL: {
|
||||||
|
|
|
@ -178,7 +178,7 @@ void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& inp
|
||||||
const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&gradI}, {&input, &gradO});
|
NDArray::prepareSpecialUse({&gradI}, {&input, &gradO});
|
||||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), NUMERIC_TYPES);
|
||||||
NDArray::registerSpecialUse({&gradI}, {&input, &gradO});
|
NDArray::registerSpecialUse({&gradI}, {&input, &gradO});
|
||||||
|
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
|
|
|
@ -1,5 +1,2 @@
|
||||||
Identity,Variable/read
|
Identity,in_0/read
|
||||||
Identity,Variable_1/read
|
MaxPoolWithArgmax,MaxPoolWithArgmax
|
||||||
Pack,floordiv/x
|
|
||||||
Pack,floordiv/y
|
|
||||||
FloorDiv,floordiv
|
|
||||||
|
|
|
@ -69,14 +69,11 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
* the status of the test failing. No tests will run.
|
* the status of the test failing. No tests will run.
|
||||||
*/
|
*/
|
||||||
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(
|
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(
|
||||||
// "max_pool_with_argmax/int32_int64_padding_SAME",
|
"max_pool_with_argmax/int32_int64_padding_SAME",
|
||||||
// "fused_batch_norm/float32_nhwc",
|
// "fused_batch_norm/float32_nhwc",
|
||||||
// "max_pool_with_argmax/int64_int64_padding_SAME",
|
"max_pool_with_argmax/int64_int64_padding_SAME"
|
||||||
// "fused_batch_norm/float16_nhwc",
|
// "fused_batch_norm/float16_nhwc",
|
||||||
"roll/rank3_int32_axis",
|
|
||||||
"roll/rank3_int32_axis",
|
|
||||||
"roll/rank2_float32_zeroshift",
|
|
||||||
"roll/rank3_float64_axis"
|
|
||||||
);
|
);
|
||||||
|
|
||||||
public static final String[] IGNORE_REGEXES = new String[]{
|
public static final String[] IGNORE_REGEXES = new String[]{
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
in_0/read,in_0/read
|
in_0/read,in_0/read
|
||||||
Roll,Roll
|
MaxPoolWithArgmax,MaxPoolWithArgmax
|
||||||
|
MaxPoolWithArgmax:1,MaxPoolWithArgmax
|
||||||
|
|
|
@ -99,13 +99,6 @@ class TestTensorflowIR {
|
||||||
val output2 = importedGraph.outputAll(inputMap)
|
val output2 = importedGraph.outputAll(inputMap)
|
||||||
|
|
||||||
|
|
||||||
val matrix =
|
|
||||||
TensorflowIRTensor(tensorflowIRGraph.nodeByName("in_0").attrMap["value"]!!.tensor).toNd4jNDArray()
|
|
||||||
val roll2 = Roll(matrix, Nd4j.scalar(2), Nd4j.scalar(1))
|
|
||||||
val outputs = Nd4j.exec(roll2)[0]
|
|
||||||
val tfOutputRoll = tfOutput["Roll"]
|
|
||||||
val nd4jOutput = output["Roll"]
|
|
||||||
|
|
||||||
//assertEquals(tfOutput.keys,outputList)
|
//assertEquals(tfOutput.keys,outputList)
|
||||||
//assertEquals(tfOutput.keys,output2.keys)
|
//assertEquals(tfOutput.keys,output2.keys)
|
||||||
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
|
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
|
||||||
|
|
Loading…
Reference in New Issue