diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp index 9df63556e..3bf97e586 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp @@ -178,7 +178,11 @@ PLATFORM_CHECK(concat, ENGINE_CPU) { const auto zType = z->dataType(); - return z->rankOf() < 7 && (zType==DataType::FLOAT32 || zType==DataType::HALF || zType==DataType::BFLOAT16 || zType==DataType::UINT8 || zType==DataType::INT8); + const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); + + return z->rankOf() < 7 && numOfInArrs <= 3072 + && (zType==DataType::FLOAT32 || zType==DataType::HALF || zType==DataType::BFLOAT16 || zType==DataType::UINT8 || zType==DataType::INT8); } }