Update allowed data types in pooling
parent
53bfdb9994
commit
968eaad2dd
|
@ -218,7 +218,7 @@ namespace sd {
|
|||
}
|
||||
|
||||
void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), NUMERIC_TYPES);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -301,7 +301,7 @@ namespace sd {
|
|||
}
|
||||
|
||||
void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), NUMERIC_TYPES);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -315,11 +315,11 @@ void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input
|
|||
switch (poolingMode) {
|
||||
|
||||
case MAX_POOL: {
|
||||
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), NUMERIC_TYPES);
|
||||
}
|
||||
break;
|
||||
case AVG_POOL: {
|
||||
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), NUMERIC_TYPES);
|
||||
}
|
||||
break;
|
||||
case PNORM_POOL: {
|
||||
|
|
|
@ -178,7 +178,7 @@ void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& inp
|
|||
const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
||||
|
||||
NDArray::prepareSpecialUse({&gradI}, {&input, &gradO});
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), NUMERIC_TYPES);
|
||||
NDArray::registerSpecialUse({&gradI}, {&input, &gradO});
|
||||
|
||||
manager.synchronize();
|
||||
|
|
|
@ -1,5 +1,2 @@
|
|||
Identity,Variable/read
|
||||
Identity,Variable_1/read
|
||||
Pack,floordiv/x
|
||||
Pack,floordiv/y
|
||||
FloorDiv,floordiv
|
||||
Identity,in_0/read
|
||||
MaxPoolWithArgmax,MaxPoolWithArgmax
|
||||
|
|
|
@ -69,14 +69,11 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
* the status of the test failing. No tests will run.
|
||||
*/
|
||||
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(
|
||||
// "max_pool_with_argmax/int32_int64_padding_SAME",
|
||||
"max_pool_with_argmax/int32_int64_padding_SAME",
|
||||
// "fused_batch_norm/float32_nhwc",
|
||||
// "max_pool_with_argmax/int64_int64_padding_SAME",
|
||||
"max_pool_with_argmax/int64_int64_padding_SAME"
|
||||
// "fused_batch_norm/float16_nhwc",
|
||||
"roll/rank3_int32_axis",
|
||||
"roll/rank3_int32_axis",
|
||||
"roll/rank2_float32_zeroshift",
|
||||
"roll/rank3_float64_axis"
|
||||
|
||||
);
|
||||
|
||||
public static final String[] IGNORE_REGEXES = new String[]{
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
in_0/read,in_0/read
|
||||
Roll,Roll
|
||||
MaxPoolWithArgmax,MaxPoolWithArgmax
|
||||
MaxPoolWithArgmax:1,MaxPoolWithArgmax
|
||||
|
|
|
@ -99,13 +99,6 @@ class TestTensorflowIR {
|
|||
val output2 = importedGraph.outputAll(inputMap)
|
||||
|
||||
|
||||
val matrix =
|
||||
TensorflowIRTensor(tensorflowIRGraph.nodeByName("in_0").attrMap["value"]!!.tensor).toNd4jNDArray()
|
||||
val roll2 = Roll(matrix, Nd4j.scalar(2), Nd4j.scalar(1))
|
||||
val outputs = Nd4j.exec(roll2)[0]
|
||||
val tfOutputRoll = tfOutput["Roll"]
|
||||
val nd4jOutput = output["Roll"]
|
||||
|
||||
//assertEquals(tfOutput.keys,outputList)
|
||||
//assertEquals(tfOutput.keys,output2.keys)
|
||||
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
|
||||
|
|
Loading…
Reference in New Issue