Update allowed data types in pooling

master
agibsonccc 2021-02-07 19:53:55 +09:00
parent 53bfdb9994
commit 968eaad2dd
8 changed files with 12 additions and 24 deletions

View File

@ -218,7 +218,7 @@ namespace sd {
}
void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) {
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), NUMERIC_TYPES);
}
}

View File

@ -301,7 +301,7 @@ namespace sd {
}
void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), NUMERIC_TYPES);
}
}

View File

@ -315,11 +315,11 @@ void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input
switch (poolingMode) {
case MAX_POOL: {
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), NUMERIC_TYPES);
}
break;
case AVG_POOL: {
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), NUMERIC_TYPES);
}
break;
case PNORM_POOL: {

View File

@ -178,7 +178,7 @@ void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& inp
const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
NDArray::prepareSpecialUse({&gradI}, {&input, &gradO});
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), NUMERIC_TYPES);
NDArray::registerSpecialUse({&gradI}, {&input, &gradO});
manager.synchronize();

View File

@ -1,5 +1,2 @@
Identity,Variable/read
Identity,Variable_1/read
Pack,floordiv/x
Pack,floordiv/y
FloorDiv,floordiv
Identity,in_0/read
MaxPoolWithArgmax,MaxPoolWithArgmax

View File

@ -69,14 +69,11 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
* the status of the test failing. No tests will run.
*/
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(
// "max_pool_with_argmax/int32_int64_padding_SAME",
"max_pool_with_argmax/int32_int64_padding_SAME",
// "fused_batch_norm/float32_nhwc",
// "max_pool_with_argmax/int64_int64_padding_SAME",
"max_pool_with_argmax/int64_int64_padding_SAME"
// "fused_batch_norm/float16_nhwc",
"roll/rank3_int32_axis",
"roll/rank3_int32_axis",
"roll/rank2_float32_zeroshift",
"roll/rank3_float64_axis"
);
public static final String[] IGNORE_REGEXES = new String[]{

View File

@ -1,2 +1,3 @@
in_0/read,in_0/read
Roll,Roll
MaxPoolWithArgmax,MaxPoolWithArgmax
MaxPoolWithArgmax:1,MaxPoolWithArgmax

View File

@ -99,13 +99,6 @@ class TestTensorflowIR {
val output2 = importedGraph.outputAll(inputMap)
val matrix =
TensorflowIRTensor(tensorflowIRGraph.nodeByName("in_0").attrMap["value"]!!.tensor).toNd4jNDArray()
val roll2 = Roll(matrix, Nd4j.scalar(2), Nd4j.scalar(1))
val outputs = Nd4j.exec(roll2)[0]
val tfOutputRoll = tfOutput["Roll"]
val nd4jOutput = output["Roll"]
//assertEquals(tfOutput.keys,outputList)
//assertEquals(tfOutput.keys,output2.keys)
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }