Shyrma transpose (#244)
* - provide contiguous strides for ouput in transpose op Signed-off-by: Yurii <iuriish@yahoo.com> * - provide contiguous strides for output in permute op Signed-off-by: Yurii <iuriish@yahoo.com> * - take into account empty shapes properly in transpose/permute op Signed-off-by: Yurii <iuriish@yahoo.com>
This commit is contained in:
		
							parent
							
								
									9e3c1b02b1
								
							
						
					
					
						commit
						011c272fde
					
				@ -1976,7 +1976,7 @@ bool NDArray::permutei(const std::initializer_list<int>& dimensions) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
bool NDArray::permutei(const std::vector<int>& dimensions) {
 | 
					bool NDArray::permutei(const std::vector<int>& dimensions) {
 | 
				
			||||||
    return permutei(dimensions.data(), dimensions.size());
 | 
					    return permutei(dimensions.data(), rankOf());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
@ -1998,7 +1998,7 @@ bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
 | 
				
			|||||||
    for (int e = 0; e < dimensions.size(); e++)
 | 
					    for (int e = 0; e < dimensions.size(); e++)
 | 
				
			||||||
        ivec[e] = dimensions[e];
 | 
					        ivec[e] = dimensions[e];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return permutei(ivec.data(), ivec.size());
 | 
					    return permutei(ivec.data(), rankOf());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
@ -2034,9 +2034,8 @@ NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
NDArray NDArray::permute(const std::vector<int>& dimensions) const &{
 | 
					NDArray NDArray::permute(const std::vector<int>& dimensions) const &{
 | 
				
			||||||
    auto data = dimensions.data();
 | 
					
 | 
				
			||||||
    auto size = dimensions.size();
 | 
					    return permute(dimensions.data(), rankOf());
 | 
				
			||||||
    return permute(data, size);
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
@ -2048,7 +2047,8 @@ NDArray NDArray::permute(const std::vector<int>& dimensions) && {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const & {
 | 
					NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const & {
 | 
				
			||||||
    return permute(dimensions.data(), dimensions.size());
 | 
					
 | 
				
			||||||
 | 
					    return permute(dimensions.data(), rankOf());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
@ -2111,12 +2111,12 @@ void NDArray::permute(const Nd4jLong *dimensions, const int rank, NDArray& targe
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
void NDArray::permute(const std::vector<int>& dimensions, NDArray& target) const {
 | 
					void NDArray::permute(const std::vector<int>& dimensions, NDArray& target) const {
 | 
				
			||||||
    permute(dimensions.data(), dimensions.size(), target);
 | 
					    permute(dimensions.data(), rankOf(), target);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
void NDArray::permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const {
 | 
					void NDArray::permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const {
 | 
				
			||||||
    permute(dimensions.data(), dimensions.size(), target);
 | 
					    permute(dimensions.data(), rankOf(), target);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
				
			|||||||
@ -50,11 +50,13 @@ namespace nd4j {
 | 
				
			|||||||
    	static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
 | 
					    	static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // evaluate shapeInfo of permuted array
 | 
					        // evaluate shapeInfo of permuted array
 | 
				
			||||||
        static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
 | 
					        // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
 | 
				
			||||||
 | 
					        static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
 | 
				
			||||||
        static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
 | 
					        static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // evaluate shapeInfo of transposed array
 | 
					        // evaluate shapeInfo of transposed array
 | 
				
			||||||
        static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace);
 | 
					        // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
 | 
				
			||||||
 | 
					        static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);
 | 
					        static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -313,32 +313,37 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
 | 
				
			|||||||
    return outShape;
 | 
					    return outShape;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
// evaluate shapeInfo of permuted array
 | 
					// evaluate shapeInfo of permuted array
 | 
				
			||||||
    Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
 | 
					Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (!arr.nonNull())
 | 
					    if (!arr.nonNull())
 | 
				
			||||||
            throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
 | 
					        throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (rank != arr.rankOf())
 | 
					    if (rank != arr.rankOf())
 | 
				
			||||||
            throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
 | 
					        throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        auto shapeInfoLength = shape::shapeInfoLength(rank);
 | 
					    auto shapeInfoLength = shape::shapeInfoLength(rank);
 | 
				
			||||||
        // allocate memory for new array - shapeInfo
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Nd4jLong *shapeInfoNew = nullptr;
 | 
					    // allocate memory for new array - shapeInfo
 | 
				
			||||||
        ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
 | 
					    Nd4jLong *shapeInfoNew = nullptr;
 | 
				
			||||||
        // copy arr _shapeInfo into new array
 | 
					    ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
 | 
				
			||||||
        memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
 | 
					 | 
				
			||||||
        // perform buffer permutation
 | 
					 | 
				
			||||||
        shape::doPermuteShapeInfo(shapeInfoNew, dimensions);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        ShapeDescriptor descriptor(shapeInfoNew);
 | 
					    // copy arr _shapeInfo into new array
 | 
				
			||||||
        RELEASE(shapeInfoNew, workspace);
 | 
					    memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
 | 
				
			||||||
        return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // perform buffer permutation
 | 
				
			||||||
 | 
					    shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if(setContigStrides)
 | 
				
			||||||
 | 
					        shape::updateStrides(shapeInfoNew, arr.ordering());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ShapeDescriptor descriptor(shapeInfoNew);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    RELEASE(shapeInfoNew, workspace);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    //////////////////////////////////////////////////////////////////////////
 | 
					    //////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
    // evaluate shapeInfo of permuted array
 | 
					    // evaluate shapeInfo of permuted array
 | 
				
			||||||
@ -350,14 +355,14 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
// evaluate shapeInfo of transposed array
 | 
					// evaluate shapeInfo of transposed array
 | 
				
			||||||
    Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace) {
 | 
					    Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        int rank = arr.rankOf();
 | 
					        int rank = arr.rankOf();
 | 
				
			||||||
        std::vector<int> dimensions(rank);
 | 
					        std::vector<int> dimensions(rank);
 | 
				
			||||||
        for (int i = 0; i < rank; ++i)
 | 
					        for (int i = 0; i < rank; ++i)
 | 
				
			||||||
            dimensions[i] = rank - 1 - i;
 | 
					            dimensions[i] = rank - 1 - i;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace);
 | 
					        return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, setContigStrides);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
				
			|||||||
@ -15,7 +15,8 @@
 | 
				
			|||||||
 ******************************************************************************/
 | 
					 ******************************************************************************/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Created by raver119 on 29/10/17.
 | 
					// @author raver119@gmail.com
 | 
				
			||||||
 | 
					// @author Yurii Shyrma (iuriish@yahoo.com)
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <op_boilerplate.h>
 | 
					#include <op_boilerplate.h>
 | 
				
			||||||
@ -29,80 +30,52 @@ namespace nd4j {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
// here iArgs is int vector of ordered set of dimensions to be permuted
 | 
					// here iArgs is int vector of ordered set of dimensions to be permuted
 | 
				
			||||||
        CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
 | 
					CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
 | 
				
			||||||
            auto x = INPUT_VARIABLE(0);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            bool replace = false;
 | 
					    auto x = INPUT_VARIABLE(0);
 | 
				
			||||||
 | 
					    auto z = OUTPUT_VARIABLE(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
 | 
					    if (x->isEmpty()) {
 | 
				
			||||||
            std::vector<int> arguments({});
 | 
					        REQUIRE_TRUE(z->isEmpty(), 0, "PERMUTE OP: when input is empty, output must also be empty");
 | 
				
			||||||
            if(origArgs.size() > 0){
 | 
					        return Status::OK();    //No op
 | 
				
			||||||
                for (int e = 0; e < origArgs.size(); e++) {
 | 
					 | 
				
			||||||
                    int ax = origArgs[e];
 | 
					 | 
				
			||||||
                    if (ax < 0)
 | 
					 | 
				
			||||||
                        ax += x->rankOf();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    arguments.emplace_back(ax);
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                replace = true;
 | 
					 | 
				
			||||||
            } else {
 | 
					 | 
				
			||||||
                for (int e = x->rankOf() - 1; e >= 0; e--)
 | 
					 | 
				
			||||||
                    arguments.emplace_back(e);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            // 0D edge case
 | 
					 | 
				
			||||||
            if (x->rankOf() == 0) {
 | 
					 | 
				
			||||||
                REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
 | 
					 | 
				
			||||||
                auto output = OUTPUT_VARIABLE(0);
 | 
					 | 
				
			||||||
                if (!block.isInplace())
 | 
					 | 
				
			||||||
                    output->assign(x);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                return Status::OK();
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if(block.isInplace()) {		// in-place
 | 
					 | 
				
			||||||
                x->permutei(arguments);
 | 
					 | 
				
			||||||
                STORE_RESULT(x);
 | 
					 | 
				
			||||||
            } else {
 | 
					 | 
				
			||||||
                auto output = OUTPUT_VARIABLE(0);
 | 
					 | 
				
			||||||
                auto result = x->permute(arguments);
 | 
					 | 
				
			||||||
                output->assign(result);
 | 
					 | 
				
			||||||
                STORE_RESULT(output);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            return Status::OK();
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        DECLARE_TYPES(permute) {
 | 
					 | 
				
			||||||
            getOpDescriptor()
 | 
					 | 
				
			||||||
                    ->setAllowedInputTypes(0, nd4j::DataType::ANY)
 | 
					 | 
				
			||||||
                    ->setAllowedInputTypes(1, {ALL_INTS})
 | 
					 | 
				
			||||||
                    ->setSameMode(true);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        DECLARE_SHAPE_FN(permute) {
 | 
					 | 
				
			||||||
            auto shapeList = SHAPELIST();
 | 
					 | 
				
			||||||
            auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if (shape::rank(inputShape->at(0)) == 0) {
 | 
					 | 
				
			||||||
                shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
 | 
					 | 
				
			||||||
            } else if (inputShape->size() == 1 && !arguments.empty()) {
 | 
					 | 
				
			||||||
                shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
 | 
					 | 
				
			||||||
            } else {
 | 
					 | 
				
			||||||
                if(arguments.size() == 0){
 | 
					 | 
				
			||||||
                    //Reverse dimensions
 | 
					 | 
				
			||||||
                    int rank = shape::rank(inputShape->at(0));
 | 
					 | 
				
			||||||
                    for (int e = rank - 1; e >= 0; e--)
 | 
					 | 
				
			||||||
                        arguments.emplace_back(e);
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
            return shapeList;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (block.width() == 1 && block.getIArguments()->size() == 0) {
 | 
				
			||||||
 | 
					        z->assign(x->transpose());
 | 
				
			||||||
 | 
					        return Status::OK();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    z->assign(x->permute(permutationVector));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return Status::OK();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
					DECLARE_TYPES(permute) {
 | 
				
			||||||
 | 
					    getOpDescriptor()
 | 
				
			||||||
 | 
					            ->setAllowedInputTypes(0, nd4j::DataType::ANY)
 | 
				
			||||||
 | 
					            ->setAllowedInputTypes(1, {ALL_INTS})
 | 
				
			||||||
 | 
					            ->setSameMode(true);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
					DECLARE_SHAPE_FN(permute) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto x = INPUT_VARIABLE(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (block.width() == 1 && block.getIArguments()->size() == 0)
 | 
				
			||||||
 | 
					        return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return SHAPELIST(outputShapeInfo);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
@ -24,254 +24,240 @@
 | 
				
			|||||||
#include <ops/declarable/CustomOperations.h>
 | 
					#include <ops/declarable/CustomOperations.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace nd4j {
 | 
					namespace nd4j {
 | 
				
			||||||
    namespace ops {
 | 
					namespace ops  {
 | 
				
			||||||
        //////////////////////////////////////////////////////////////////////////
 | 
					 | 
				
			||||||
        // here iArgs is a vector with (optional) negative of order as first element:
 | 
					 | 
				
			||||||
        // ({-order, dim1, dim2, dim3, ...})
 | 
					 | 
				
			||||||
        CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
 | 
					 | 
				
			||||||
            auto x = INPUT_VARIABLE(0);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if (block.width() == 1) {
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
                auto arguments = block.getIArguments();
 | 
					// here iArgs is a vector with (optional) negative of order as first element:
 | 
				
			||||||
                int argsSize = arguments->size();
 | 
					// ({-order, dim1, dim2, dim3, ...})
 | 
				
			||||||
 | 
					CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                //Special case: empty.reshape(<other empty shape>) -> return empty
 | 
					    auto x = INPUT_VARIABLE(0);
 | 
				
			||||||
                if (x->isEmpty()) {
 | 
					    auto z = OUTPUT_VARIABLE(0);
 | 
				
			||||||
                    REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
 | 
					
 | 
				
			||||||
                    return ND4J_STATUS_OK;    //No op
 | 
					    //Special case: empty.reshape(<other empty shape>) -> return empty
 | 
				
			||||||
 | 
					        if (x->isEmpty()) {
 | 
				
			||||||
 | 
					            REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
 | 
				
			||||||
 | 
					            return Status::OK();    //No op
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (block.width() == 1) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        auto arguments = block.getIArguments();
 | 
				
			||||||
 | 
					        int argsSize = arguments->size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        int e = 1;
 | 
				
			||||||
 | 
					        char order = (char) -(*arguments)[0];
 | 
				
			||||||
 | 
					        if (order != 'c' && order != 'f') {
 | 
				
			||||||
 | 
					            order = 'c'; //x->ordering();
 | 
				
			||||||
 | 
					            e = 0;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        std::vector<Nd4jLong> shapeNew;
 | 
				
			||||||
 | 
					        int e2 = e;
 | 
				
			||||||
 | 
					        for (; e < (int) arguments->size(); e++) {
 | 
				
			||||||
 | 
					            if (arguments->at(e) == -1){
 | 
				
			||||||
 | 
					                Nd4jLong shapeLength = 1;
 | 
				
			||||||
 | 
					                for(; e2 < e; e2++){
 | 
				
			||||||
 | 
					                    shapeLength *= arguments->at(e2);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					                for(e2 = e + 1; e2 < arguments->size(); e2++){
 | 
				
			||||||
                int e = 1;
 | 
					                    shapeLength *= arguments->at(e2);
 | 
				
			||||||
                char order = (char) -(*arguments)[0];
 | 
					 | 
				
			||||||
                if (order != 'c' && order != 'f') {
 | 
					 | 
				
			||||||
                    order = 'c'; //x->ordering();
 | 
					 | 
				
			||||||
                    e = 0;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                std::vector<Nd4jLong> shapeNew;
 | 
					 | 
				
			||||||
                int e2 = e;
 | 
					 | 
				
			||||||
                for (; e < (int) arguments->size(); e++) {
 | 
					 | 
				
			||||||
                    if (arguments->at(e) == -1){
 | 
					 | 
				
			||||||
                        Nd4jLong shapeLength = 1;
 | 
					 | 
				
			||||||
                        for(; e2 < e; e2++){
 | 
					 | 
				
			||||||
                            shapeLength *= arguments->at(e2);
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                        for(e2 = e + 1; e2 < arguments->size(); e2++){
 | 
					 | 
				
			||||||
                            shapeLength *= arguments->at(e2);
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                        Nd4jLong realShape = x->lengthOf() / shapeLength;
 | 
					 | 
				
			||||||
                        shapeNew.push_back(realShape);
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                    else{
 | 
					 | 
				
			||||||
                        shapeNew.push_back(arguments->at(e));
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
 | 
					 | 
				
			||||||
                REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if (Environment::getInstance()->isDebugAndVerbose()) {
 | 
					 | 
				
			||||||
                    nd4j_printv("Reshape: new shape", shapeNew);
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if (block.isInplace()) {
 | 
					 | 
				
			||||||
                    if (x->reshapei(order, shapeNew)) {
 | 
					 | 
				
			||||||
                        STORE_RESULT(*x);
 | 
					 | 
				
			||||||
                        return ND4J_STATUS_OK;
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                } else {
 | 
					 | 
				
			||||||
                    auto ret = OUTPUT_VARIABLE(0);
 | 
					 | 
				
			||||||
                    auto xr = x->reshape(order, shapeNew);
 | 
					 | 
				
			||||||
                    ret->assign(xr);
 | 
					 | 
				
			||||||
                    STORE_RESULT(*ret);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    return Status::OK();
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            } else if (block.width() == 2) {
 | 
					 | 
				
			||||||
                auto s = INPUT_VARIABLE(1);
 | 
					 | 
				
			||||||
                
 | 
					 | 
				
			||||||
                //Special case: empty.reshape(-1) -> return empty
 | 
					 | 
				
			||||||
                if (x->isEmpty()) {
 | 
					 | 
				
			||||||
                    //REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
 | 
					 | 
				
			||||||
                    REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
 | 
					 | 
				
			||||||
                    return Status::OK();    //No op
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                char order = 'c';
 | 
					 | 
				
			||||||
                if (block.numI() > 0)
 | 
					 | 
				
			||||||
                    order = (char) -INT_ARG(0);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                std::vector<Nd4jLong> shapeNew(s->lengthOf());
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                for (int e = 0; e < (int) s->lengthOf(); e++) {
 | 
					 | 
				
			||||||
                    auto dim = s->e<Nd4jLong >(e);
 | 
					 | 
				
			||||||
                    if (dim == -1){
 | 
					 | 
				
			||||||
                        Nd4jLong shapeLength = 1;
 | 
					 | 
				
			||||||
                        for(int e2 = 0; e2 < e; e2++){
 | 
					 | 
				
			||||||
                            shapeLength *= s->e<Nd4jLong>(e2);
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                        for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
 | 
					 | 
				
			||||||
                            REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
 | 
					 | 
				
			||||||
                            shapeLength *= s->e<Nd4jLong>(e2);
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                        Nd4jLong realShape = x->lengthOf() / shapeLength;
 | 
					 | 
				
			||||||
                        shapeNew[e] = realShape;
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                    else{
 | 
					 | 
				
			||||||
                        shapeNew[e] = dim;
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if (Environment::getInstance()->isDebugAndVerbose()) {
 | 
					 | 
				
			||||||
                    nd4j_printv("Reshape: new shape", shapeNew);
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if (block.isInplace()) {
 | 
					 | 
				
			||||||
                    if (x->reshapei(order, shapeNew)) {
 | 
					 | 
				
			||||||
                        STORE_RESULT(*x);
 | 
					 | 
				
			||||||
                        return Status::OK();
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                } else {
 | 
					 | 
				
			||||||
                    auto ret = OUTPUT_VARIABLE(0);
 | 
					 | 
				
			||||||
                    if (s->isEmpty()) {
 | 
					 | 
				
			||||||
                        // just a scalar
 | 
					 | 
				
			||||||
                        ret->assign(x);
 | 
					 | 
				
			||||||
                    } else {
 | 
					 | 
				
			||||||
                        auto xr = x->reshape(order, shapeNew);
 | 
					 | 
				
			||||||
                        ret->assign(xr);
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    return Status::OK();
 | 
					 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					                Nd4jLong realShape = x->lengthOf() / shapeLength;
 | 
				
			||||||
 | 
					                shapeNew.push_back(realShape);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            else{
 | 
				
			||||||
 | 
					                shapeNew.push_back(arguments->at(e));
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            return ND4J_STATUS_BAD_INPUT;
 | 
					 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
 | 
				
			||||||
 | 
					        REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DECLARE_TYPES(reshape) {
 | 
					        if (Environment::getInstance()->isDebugAndVerbose()) {
 | 
				
			||||||
            getOpDescriptor()
 | 
					            nd4j_printv("Reshape: new shape", shapeNew);
 | 
				
			||||||
                    ->setAllowedInputTypes(0, nd4j::DataType::ANY)
 | 
					 | 
				
			||||||
                    ->setAllowedInputTypes(1, {ALL_INTS})
 | 
					 | 
				
			||||||
                    ->setSameMode(true);
 | 
					 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DECLARE_SHAPE_FN(reshape) {
 | 
					        auto xr = x->reshape(order, shapeNew);
 | 
				
			||||||
            auto inp = inputShape->at(0);
 | 
					        z->assign(xr);
 | 
				
			||||||
 | 
					        STORE_RESULT(*z);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            // we can launch op using Int arguments
 | 
					        return Status::OK();
 | 
				
			||||||
            if (inputShape->size() == 1) {
 | 
					 | 
				
			||||||
                REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
 | 
					 | 
				
			||||||
                std::vector<int> *arguments = block.getIArguments();
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                int e = 1;
 | 
					    } else if (block.width() == 2) {
 | 
				
			||||||
                char order = (char) -(*arguments)[0];
 | 
					
 | 
				
			||||||
                if (order != 'c' && order != 'f') {
 | 
					        auto s = INPUT_VARIABLE(1);
 | 
				
			||||||
                    order = shape::order(inp);
 | 
					
 | 
				
			||||||
                    e = 0;
 | 
					        char order = 'c';
 | 
				
			||||||
 | 
					        if (block.numI() > 0)
 | 
				
			||||||
 | 
					            order = (char) -INT_ARG(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        std::vector<Nd4jLong> shapeNew(s->lengthOf());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for (int e = 0; e < (int) s->lengthOf(); e++) {
 | 
				
			||||||
 | 
					            auto dim = s->e<Nd4jLong >(e);
 | 
				
			||||||
 | 
					            if (dim == -1){
 | 
				
			||||||
 | 
					                Nd4jLong shapeLength = 1;
 | 
				
			||||||
 | 
					                for(int e2 = 0; e2 < e; e2++){
 | 
				
			||||||
 | 
					                    shapeLength *= s->e<Nd4jLong>(e2);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					                for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
 | 
				
			||||||
                std::vector<Nd4jLong> shapeNew;
 | 
					                    REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
 | 
				
			||||||
 | 
					                    shapeLength *= s->e<Nd4jLong>(e2);
 | 
				
			||||||
                int e2 = e;
 | 
					 | 
				
			||||||
                for (; e < (int) arguments->size(); e++) {
 | 
					 | 
				
			||||||
                    if ((int) arguments->at(e) == -1){
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        Nd4jLong shapeLength = 1;
 | 
					 | 
				
			||||||
                        for(; e2 < e; e2 ++){
 | 
					 | 
				
			||||||
                            shapeLength *= arguments->at(e2);
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                        for(e2 = e + 1; e2 < arguments->size(); e2++){
 | 
					 | 
				
			||||||
                            REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
 | 
					 | 
				
			||||||
                            shapeLength *= arguments->at(e2);
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        if(shapeLength == 0){
 | 
					 | 
				
			||||||
                            //Edge case for empty:
 | 
					 | 
				
			||||||
                            shapeNew.push_back(0);
 | 
					 | 
				
			||||||
                        } else {
 | 
					 | 
				
			||||||
                            //Standard case
 | 
					 | 
				
			||||||
                            Nd4jLong realShape = shape::length(inp) / shapeLength;
 | 
					 | 
				
			||||||
                            shapeNew.push_back(realShape);
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                    else{
 | 
					 | 
				
			||||||
                        shapeNew.push_back(arguments->at(e));
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					                Nd4jLong realShape = x->lengthOf() / shapeLength;
 | 
				
			||||||
                return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
 | 
					                shapeNew[e] = realShape;
 | 
				
			||||||
            } else {
 | 
					            }
 | 
				
			||||||
                // or, with second input "as shape"
 | 
					            else{
 | 
				
			||||||
                auto x = INPUT_VARIABLE(0);
 | 
					                shapeNew[e] = dim;
 | 
				
			||||||
                auto y = INPUT_VARIABLE(1);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // special case here
 | 
					 | 
				
			||||||
                if (y->isEmpty()) {
 | 
					 | 
				
			||||||
                    REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
 | 
					 | 
				
			||||||
                    return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
                //Special case: empty.reshape(-1) -> return empty
 | 
					 | 
				
			||||||
                if (x->isEmpty()) {
 | 
					 | 
				
			||||||
                    //REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
 | 
					 | 
				
			||||||
                    auto shapeOf = y->getBufferAsVector<Nd4jLong>();
 | 
					 | 
				
			||||||
                    Nd4jLong prod = 1;
 | 
					 | 
				
			||||||
                    bool hasNegs = false;
 | 
					 | 
				
			||||||
                    for (auto v:shapeOf) {
 | 
					 | 
				
			||||||
                        if (v < 0) {
 | 
					 | 
				
			||||||
                            hasNegs = true;
 | 
					 | 
				
			||||||
                            v = 0;
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        prod *= v;
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    // if there are -1s - we turn them into zeros
 | 
					 | 
				
			||||||
                    if (hasNegs) {
 | 
					 | 
				
			||||||
                        for (int e = 0; e < shapeOf.size(); e++)
 | 
					 | 
				
			||||||
                            if (shapeOf[e] < 0)
 | 
					 | 
				
			||||||
                                shapeOf[e] = 0;
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
 | 
					 | 
				
			||||||
                    return SHAPELIST(CONSTANT(newShape));
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                std::vector<Nd4jLong> shapeNew(y->lengthOf());
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                for (int e = 0; e < (int) y->lengthOf(); e++) {
 | 
					 | 
				
			||||||
                    auto dim = y->e<Nd4jLong>(e);
 | 
					 | 
				
			||||||
                    if (dim == -1){
 | 
					 | 
				
			||||||
                        Nd4jLong shapeLength = 1;
 | 
					 | 
				
			||||||
                        for(int e2 = 0; e2 < e; e2++){
 | 
					 | 
				
			||||||
                            shapeLength *= y->e<Nd4jLong>(e2);
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                        for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
 | 
					 | 
				
			||||||
                            REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
 | 
					 | 
				
			||||||
                            shapeLength *= y->e<Nd4jLong>(e2);
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        if(shapeLength == 0){
 | 
					 | 
				
			||||||
                            //Edge case for empty:
 | 
					 | 
				
			||||||
                            shapeNew[e] = 0;
 | 
					 | 
				
			||||||
                        } else {
 | 
					 | 
				
			||||||
                            Nd4jLong realShape = shape::length(inp) / shapeLength;
 | 
					 | 
				
			||||||
                            shapeNew[e] = realShape;
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                    }else {
 | 
					 | 
				
			||||||
                        shapeNew[e] = dim;
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (Environment::getInstance()->isDebugAndVerbose()) {
 | 
				
			||||||
 | 
					            nd4j_printv("Reshape: new shape", shapeNew);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (s->isEmpty()) {
 | 
				
			||||||
 | 
					            // just a scalar
 | 
				
			||||||
 | 
					            z->assign(x);
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            auto xr = x->reshape(order, shapeNew);
 | 
				
			||||||
 | 
					            z->assign(xr);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return Status::OK();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return ND4J_STATUS_BAD_INPUT;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DECLARE_TYPES(reshape) {
 | 
				
			||||||
 | 
					    getOpDescriptor()
 | 
				
			||||||
 | 
					            ->setAllowedInputTypes(0, nd4j::DataType::ANY)
 | 
				
			||||||
 | 
					            ->setAllowedInputTypes(1, {ALL_INTS})
 | 
				
			||||||
 | 
					            ->setSameMode(true);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DECLARE_SHAPE_FN(reshape) {
 | 
				
			||||||
 | 
					    auto inp = inputShape->at(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // we can launch op using Int arguments
 | 
				
			||||||
 | 
					    if (inputShape->size() == 1) {
 | 
				
			||||||
 | 
					        REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
 | 
				
			||||||
 | 
					        std::vector<int> *arguments = block.getIArguments();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        int e = 1;
 | 
				
			||||||
 | 
					        char order = (char) -(*arguments)[0];
 | 
				
			||||||
 | 
					        if (order != 'c' && order != 'f') {
 | 
				
			||||||
 | 
					            order = shape::order(inp);
 | 
				
			||||||
 | 
					            e = 0;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        std::vector<Nd4jLong> shapeNew;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        int e2 = e;
 | 
				
			||||||
 | 
					        for (; e < (int) arguments->size(); e++) {
 | 
				
			||||||
 | 
					            if ((int) arguments->at(e) == -1){
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                Nd4jLong shapeLength = 1;
 | 
				
			||||||
 | 
					                for(; e2 < e; e2 ++){
 | 
				
			||||||
 | 
					                    shapeLength *= arguments->at(e2);
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                for(e2 = e + 1; e2 < arguments->size(); e2++){
 | 
				
			||||||
 | 
					                    REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
 | 
				
			||||||
 | 
					                    shapeLength *= arguments->at(e2);
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if(shapeLength == 0){
 | 
				
			||||||
 | 
					                    //Edge case for empty:
 | 
				
			||||||
 | 
					                    shapeNew.push_back(0);
 | 
				
			||||||
 | 
					                } else {
 | 
				
			||||||
 | 
					                    //Standard case
 | 
				
			||||||
 | 
					                    Nd4jLong realShape = shape::length(inp) / shapeLength;
 | 
				
			||||||
 | 
					                    shapeNew.push_back(realShape);
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            else{
 | 
				
			||||||
 | 
					                shapeNew.push_back(arguments->at(e));
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					        // or, with second input "as shape"
 | 
				
			||||||
 | 
					        auto x = INPUT_VARIABLE(0);
 | 
				
			||||||
 | 
					        auto y = INPUT_VARIABLE(1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // special case here
 | 
				
			||||||
 | 
					        if (y->isEmpty()) {
 | 
				
			||||||
 | 
					            REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
 | 
				
			||||||
 | 
					            return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        //Special case: empty.reshape(-1) -> return empty
 | 
				
			||||||
 | 
					        if (x->isEmpty()) {
 | 
				
			||||||
 | 
					            //REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
 | 
				
			||||||
 | 
					            auto shapeOf = y->getBufferAsVector<Nd4jLong>();
 | 
				
			||||||
 | 
					            Nd4jLong prod = 1;
 | 
				
			||||||
 | 
					            bool hasNegs = false;
 | 
				
			||||||
 | 
					            for (auto v:shapeOf) {
 | 
				
			||||||
 | 
					                if (v < 0) {
 | 
				
			||||||
 | 
					                    hasNegs = true;
 | 
				
			||||||
 | 
					                    v = 0;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                prod *= v;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // if there are -1s - we turn them into zeros
 | 
				
			||||||
 | 
					            if (hasNegs) {
 | 
				
			||||||
 | 
					                for (int e = 0; e < shapeOf.size(); e++)
 | 
				
			||||||
 | 
					                    if (shapeOf[e] < 0)
 | 
				
			||||||
 | 
					                        shapeOf[e] = 0;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
 | 
				
			||||||
 | 
					            return SHAPELIST(CONSTANT(newShape));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        std::vector<Nd4jLong> shapeNew(y->lengthOf());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for (int e = 0; e < (int) y->lengthOf(); e++) {
 | 
				
			||||||
 | 
					            auto dim = y->e<Nd4jLong>(e);
 | 
				
			||||||
 | 
					            if (dim == -1){
 | 
				
			||||||
 | 
					                Nd4jLong shapeLength = 1;
 | 
				
			||||||
 | 
					                for(int e2 = 0; e2 < e; e2++){
 | 
				
			||||||
 | 
					                    shapeLength *= y->e<Nd4jLong>(e2);
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
 | 
				
			||||||
 | 
					                    REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
 | 
				
			||||||
 | 
					                    shapeLength *= y->e<Nd4jLong>(e2);
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if(shapeLength == 0){
 | 
				
			||||||
 | 
					                    //Edge case for empty:
 | 
				
			||||||
 | 
					                    shapeNew[e] = 0;
 | 
				
			||||||
 | 
					                } else {
 | 
				
			||||||
 | 
					                    Nd4jLong realShape = shape::length(inp) / shapeLength;
 | 
				
			||||||
 | 
					                    shapeNew[e] = realShape;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }else {
 | 
				
			||||||
 | 
					                shapeNew[e] = dim;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
@ -34,12 +34,10 @@ namespace nd4j {
 | 
				
			|||||||
        auto y = INPUT_VARIABLE(1);
 | 
					        auto y = INPUT_VARIABLE(1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        auto z = OUTPUT_VARIABLE(0);
 | 
					        auto z = OUTPUT_VARIABLE(0);
 | 
				
			||||||
        std::vector<Nd4jLong> shapeNew(y->shapeOf(), y->shapeOf() + y->rankOf());
 | 
					 | 
				
			||||||
        char order = y->ordering();
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (x->reshapei(order, shapeNew)) {
 | 
					        if (x->reshapei(y->ordering(), y->getShapeAsVector())) {
 | 
				
			||||||
            *z = *x;
 | 
					
 | 
				
			||||||
            STORE_RESULT(*z);
 | 
					            z->assign(x);
 | 
				
			||||||
            return Status::OK();
 | 
					            return Status::OK();
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -49,14 +47,8 @@ namespace nd4j {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    DECLARE_SHAPE_FN(reshapeas) {
 | 
					    DECLARE_SHAPE_FN(reshapeas) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto inputShapeInfo = inputShape->at(1);    
 | 
					        return SHAPELIST(ShapeBuilders::copyShapeInfo(INPUT_VARIABLE(1)->getShapeInfo(), false, block.workspace()));
 | 
				
			||||||
    int shapeInfoLength = inputShapeInfo[0]*2 + 4;
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    Nd4jLong* outputShapeInfo(nullptr);
 | 
					 | 
				
			||||||
    COPY_SHAPE(inputShapeInfo, outputShapeInfo);
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    return SHAPELIST(CONSTANT(outputShapeInfo));
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DECLARE_TYPES(reshapeas) {
 | 
					        DECLARE_TYPES(reshapeas) {
 | 
				
			||||||
            getOpDescriptor()
 | 
					            getOpDescriptor()
 | 
				
			||||||
 | 
				
			|||||||
@ -15,7 +15,8 @@
 | 
				
			|||||||
 ******************************************************************************/
 | 
					 ******************************************************************************/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Created by raver119 on 29/10/17.
 | 
					// @author raver119@gmail.com
 | 
				
			||||||
 | 
					// @author Yurii Shyrma (iuriish@yahoo.com)
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <op_boilerplate.h>
 | 
					#include <op_boilerplate.h>
 | 
				
			||||||
@ -25,113 +26,52 @@
 | 
				
			|||||||
#include <helpers/ShapeUtils.h>
 | 
					#include <helpers/ShapeUtils.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace nd4j {
 | 
					namespace nd4j {
 | 
				
			||||||
namespace ops {
 | 
					namespace ops  {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    //////////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
    CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) {
 | 
					CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) {
 | 
				
			||||||
        auto x = INPUT_VARIABLE(0);
 | 
					 | 
				
			||||||
        if (block.width() == 1) {
 | 
					 | 
				
			||||||
            if (block.isInplace()) {
 | 
					 | 
				
			||||||
                x->transposei();
 | 
					 | 
				
			||||||
                STORE_RESULT(*x);
 | 
					 | 
				
			||||||
            } else {
 | 
					 | 
				
			||||||
                auto output = OUTPUT_VARIABLE(0);
 | 
					 | 
				
			||||||
                auto t = x->transpose();
 | 
					 | 
				
			||||||
                output->assign(t);
 | 
					 | 
				
			||||||
                STORE_RESULT(*output);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        } else {
 | 
					 | 
				
			||||||
            // this is tf-mode transpose, that's nd4j permute
 | 
					 | 
				
			||||||
            bool replace = false;
 | 
					 | 
				
			||||||
            std::vector<int> arguments(*block.getIArguments());
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            auto w = block.width();
 | 
					    auto x = INPUT_VARIABLE(0);
 | 
				
			||||||
            auto a = arguments.size();
 | 
					    auto z = OUTPUT_VARIABLE(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if (w == 2 && a == 0) {
 | 
					    //Special case: empty.reshape(<other empty shape>) -> return empty
 | 
				
			||||||
                auto axis = INPUT_VARIABLE(1);
 | 
					    if (x->isEmpty()) {
 | 
				
			||||||
                for (int e = 0; e < axis->lengthOf(); e++) {
 | 
					        REQUIRE_TRUE(z->isEmpty(), 0, "TRANSPOSE OP: when input is empty, output must also be empty");
 | 
				
			||||||
                    auto ax = axis->e<int>(e);
 | 
					        return Status::OK();    //No op
 | 
				
			||||||
                    if (ax < 0)
 | 
					    }
 | 
				
			||||||
                        ax += x->rankOf();
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    arguments.emplace_back(ax);
 | 
					    if (block.width() == 1 && block.getIArguments()->size() == 0) {
 | 
				
			||||||
                }
 | 
					        z->assign(x->transpose());
 | 
				
			||||||
 | 
					 | 
				
			||||||
                replace = true;
 | 
					 | 
				
			||||||
            } else if (a == 0) {
 | 
					 | 
				
			||||||
                for (int e = x->rankOf() - 1; e >= 0; e--)
 | 
					 | 
				
			||||||
                    arguments.emplace_back(e);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            // 0D edge case
 | 
					 | 
				
			||||||
            if (x->rankOf() == 0) {
 | 
					 | 
				
			||||||
                REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
 | 
					 | 
				
			||||||
                auto output = OUTPUT_VARIABLE(0);
 | 
					 | 
				
			||||||
                if (!block.isInplace())
 | 
					 | 
				
			||||||
                    output->assign(x);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                return Status::OK();
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if(block.isInplace()) {		// in-place
 | 
					 | 
				
			||||||
                x->permutei(arguments);
 | 
					 | 
				
			||||||
                STORE_RESULT(x);
 | 
					 | 
				
			||||||
            } else {
 | 
					 | 
				
			||||||
                auto input = x->permute(arguments);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                auto output = OUTPUT_VARIABLE(0);
 | 
					 | 
				
			||||||
                output->assign(input);
 | 
					 | 
				
			||||||
             }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        return Status::OK();
 | 
					        return Status::OK();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    DECLARE_TYPES(transpose) {
 | 
					    std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
 | 
				
			||||||
        getOpDescriptor()
 | 
					 | 
				
			||||||
                ->setAllowedInputTypes(nd4j::DataType::ANY)
 | 
					 | 
				
			||||||
                ->setSameMode(true);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    DECLARE_SHAPE_FN(transpose) {
 | 
					    z->assign(x->permute(permutationVector));
 | 
				
			||||||
        if (block.width() == 1) {
 | 
					 | 
				
			||||||
            auto outputShapeInfo = ShapeUtils::evalTranspShapeInfo(*INPUT_VARIABLE(0), block.workspace());
 | 
					 | 
				
			||||||
            return SHAPELIST(outputShapeInfo);
 | 
					 | 
				
			||||||
        } else {
 | 
					 | 
				
			||||||
            // this is basically permute mode
 | 
					 | 
				
			||||||
            auto shapeList = SHAPELIST();
 | 
					 | 
				
			||||||
            auto arguments = block.getIArguments();
 | 
					 | 
				
			||||||
            if (shape::rank(inputShape->at(0)) == 0) {
 | 
					 | 
				
			||||||
                Nd4jLong *newshape;
 | 
					 | 
				
			||||||
                ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape->at(0)), Nd4jLong);
 | 
					 | 
				
			||||||
                newshape[0] = 0;
 | 
					 | 
				
			||||||
                newshape[1] = 0;
 | 
					 | 
				
			||||||
                newshape[2] = 1;
 | 
					 | 
				
			||||||
                newshape[3] = 99;
 | 
					 | 
				
			||||||
                ArrayOptions::copyDataType(newshape, inputShape->at(0));
 | 
					 | 
				
			||||||
                shapeList->push_back(newshape);
 | 
					 | 
				
			||||||
            } else if (arguments->size() > 0 || inputShape->size() > 1) {
 | 
					 | 
				
			||||||
                auto axis = arguments->size() > 0 ? *arguments : (INPUT_VARIABLE(1))->template asVectorT<int>();
 | 
					 | 
				
			||||||
                auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(axis.data(), axis.size(), *INPUT_VARIABLE(0), block.workspace());
 | 
					 | 
				
			||||||
                shapeList->push_back(outputShapeInfo);
 | 
					 | 
				
			||||||
            } else if (inputShape->size() == 2) {
 | 
					 | 
				
			||||||
                // dead end
 | 
					 | 
				
			||||||
                auto axis = INPUT_VARIABLE(1);
 | 
					 | 
				
			||||||
                auto axisV = axis->template asVectorT<Nd4jLong>();
 | 
					 | 
				
			||||||
                auto newshape = ShapeUtils::evalPermShapeInfo(axisV.data(), axisV.size(), *INPUT_VARIABLE(0), block.workspace());
 | 
					 | 
				
			||||||
                shapeList->push_back(newshape);
 | 
					 | 
				
			||||||
            } else {
 | 
					 | 
				
			||||||
                int rank = shape::rank(inputShape->at(0));
 | 
					 | 
				
			||||||
                for (int e = rank - 1; e >= 0; e--)
 | 
					 | 
				
			||||||
                    arguments->emplace_back(e);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace());
 | 
					    return Status::OK();
 | 
				
			||||||
                shapeList->push_back(outputShapeInfo);
 | 
					}
 | 
				
			||||||
            }
 | 
					
 | 
				
			||||||
 | 
					DECLARE_TYPES(transpose) {
 | 
				
			||||||
 | 
					    getOpDescriptor()
 | 
				
			||||||
 | 
					            ->setAllowedInputTypes(nd4j::DataType::ANY)
 | 
				
			||||||
 | 
					            ->setSameMode(true);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DECLARE_SHAPE_FN(transpose) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto x = INPUT_VARIABLE(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (block.width() == 1 && block.getIArguments()->size() == 0)
 | 
				
			||||||
 | 
					        return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return SHAPELIST(outputShapeInfo);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            return shapeList;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1882,36 +1882,6 @@ TEST_F(DeclarableOpsTests1, TestGemv1) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////
 | 
					 | 
				
			||||||
TEST_F(DeclarableOpsTests1, Reshape1) {
 | 
					 | 
				
			||||||
    const std::vector<Nd4jLong> xShape = {5,4,3};
 | 
					 | 
				
			||||||
    const std::vector<Nd4jLong> yShape = {3,5,4};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto x = NDArrayFactory::create_<float>('f', xShape);
 | 
					 | 
				
			||||||
    auto y = NDArrayFactory::create_<float>('f', yShape);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto variableSpace = new VariableSpace();
 | 
					 | 
				
			||||||
    variableSpace->putVariable(-1, x);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto block = new Context(1, variableSpace, true);
 | 
					 | 
				
			||||||
    block->fillInputs({-1});
 | 
					 | 
				
			||||||
    std::vector<int>* arguments = block->getIArguments();
 | 
					 | 
				
			||||||
    arguments->push_back(-y->ordering());
 | 
					 | 
				
			||||||
    arguments->push_back(3);
 | 
					 | 
				
			||||||
    arguments->push_back(5);
 | 
					 | 
				
			||||||
    arguments->push_back(4);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    nd4j::ops::reshape reshape;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    reshape.execute(block);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ASSERT_TRUE(x->isSameShape(y));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    delete y;
 | 
					 | 
				
			||||||
    delete block;
 | 
					 | 
				
			||||||
    delete variableSpace;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
//////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////
 | 
				
			||||||
TEST_F(DeclarableOpsTests1, Reshape2) {
 | 
					TEST_F(DeclarableOpsTests1, Reshape2) {
 | 
				
			||||||
    const std::vector<Nd4jLong> xShape = {5,4,3};
 | 
					    const std::vector<Nd4jLong> xShape = {5,4,3};
 | 
				
			||||||
@ -2022,37 +1992,8 @@ TEST_F(DeclarableOpsTests1, Reshape7){
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////
 | 
				
			||||||
TEST_F(DeclarableOpsTests1, Transpose1) {
 | 
					TEST_F(DeclarableOpsTests1, Transpose1) {
 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto x = NDArrayFactory::create_<float>('c', {3,5,2});
 | 
					    auto x = NDArrayFactory::create_<float>('c', {3,5,2});
 | 
				
			||||||
    auto exp = NDArrayFactory::create_<float>('f', {2,5,3});
 | 
					    auto exp = NDArrayFactory::create_<float>('c', {2,5,3});
 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto variableSpace = new VariableSpace();
 | 
					 | 
				
			||||||
    variableSpace->putVariable(-1, x);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto block = new Context(1, variableSpace, true);  // in-place
 | 
					 | 
				
			||||||
    block->fillInputs({-1});
 | 
					 | 
				
			||||||
    nd4j::ops::transpose transpose;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Nd4jStatus status = transpose.execute(block);
 | 
					 | 
				
			||||||
    ASSERT_EQ(ND4J_STATUS_OK, status);
 | 
					 | 
				
			||||||
    // ASSERT_TRUE(x.isSameShapeStrict(exp));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (int e = 0; e < x->rankOf() * 2 + 2; e++) {
 | 
					 | 
				
			||||||
        ASSERT_EQ(x->getShapeInfo()[e], exp->getShapeInfo()[e]);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
//  ASSERT_EQ(x.getShapeInfo()[x.rankOf() * 2 + 2],-exp.getShapeInfo()[x.rankOf() * 2 + 2]);
 | 
					 | 
				
			||||||
    ASSERT_EQ(x->getShapeInfo()[x->rankOf() * 2 + 3], exp->getShapeInfo()[x->rankOf() * 2 + 3]);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    delete exp;
 | 
					 | 
				
			||||||
    delete block;
 | 
					 | 
				
			||||||
    delete variableSpace;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
//////////////////////////////////////////////////////////////////////
 | 
					 | 
				
			||||||
TEST_F(DeclarableOpsTests1, Transpose2) {
 | 
					 | 
				
			||||||
    auto x = NDArrayFactory::create_<float>('c', {3,5,2});
 | 
					 | 
				
			||||||
    auto exp = NDArrayFactory::create_<float>('f', {2,5,3});
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto variableSpace = new VariableSpace();
 | 
					    auto variableSpace = new VariableSpace();
 | 
				
			||||||
    variableSpace->putVariable(-1, x);
 | 
					    variableSpace->putVariable(-1, x);
 | 
				
			||||||
@ -2066,12 +2007,10 @@ TEST_F(DeclarableOpsTests1, Transpose2) {
 | 
				
			|||||||
    ASSERT_EQ(ND4J_STATUS_OK, status);
 | 
					    ASSERT_EQ(ND4J_STATUS_OK, status);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
 | 
					    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
 | 
				
			||||||
    // ASSERT_TRUE(result->isSameShapeStrict(exp));
 | 
					
 | 
				
			||||||
    for (int e = 0; e < result->rankOf() * 2 + 2; e++) {
 | 
					    ASSERT_TRUE(exp->isSameShape(result));
 | 
				
			||||||
        ASSERT_EQ(result->getShapeInfo()[e], exp->getShapeInfo()[e]);
 | 
					    ASSERT_TRUE(exp->dataType() == result->dataType());
 | 
				
			||||||
    }
 | 
					    ASSERT_TRUE(exp->ordering() == result->ordering());
 | 
				
			||||||
    //ASSERT_EQ(result->getShapeInfo()[x.rankOf() * 2 + 2],-exp.getShapeInfo()[x.rankOf() * 2 + 2]);
 | 
					 | 
				
			||||||
    ASSERT_EQ(result->getShapeInfo()[x->rankOf() * 2 + 3], exp->getShapeInfo()[x->rankOf() * 2 + 3]);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    delete exp;
 | 
					    delete exp;
 | 
				
			||||||
    delete block;
 | 
					    delete block;
 | 
				
			||||||
@ -2079,44 +2018,12 @@ TEST_F(DeclarableOpsTests1, Transpose2) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//////////////////////////////////////////////////////////////////////
 | 
					 | 
				
			||||||
// in-place
 | 
					 | 
				
			||||||
TEST_F(DeclarableOpsTests1, Permute1) {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Nd4jLong shapeX[]   = {3, 5, 10, 15, 150, 15, 1, 0, 1, 99};
 | 
					 | 
				
			||||||
    Nd4jLong shapeExp[] = {3, 15, 5, 10, 1, 150, 15, 0, 0, 99};
 | 
					 | 
				
			||||||
    const std::vector<int> perm = {2, 0, 1};
 | 
					 | 
				
			||||||
    ArrayOptions::setDataType(shapeX, nd4j::DataType::FLOAT32);
 | 
					 | 
				
			||||||
    ArrayOptions::setDataType(shapeExp, nd4j::DataType::FLOAT32);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto x = new NDArray(shapeX,true);
 | 
					 | 
				
			||||||
    auto exp = new NDArray(shapeExp,true);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto variableSpace = new VariableSpace();
 | 
					 | 
				
			||||||
    variableSpace->putVariable(-1, x);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto block = new Context(1, variableSpace, true);  // in-place
 | 
					 | 
				
			||||||
    block->fillInputs({-1});
 | 
					 | 
				
			||||||
    std::vector<int>* arguments = block->getIArguments();
 | 
					 | 
				
			||||||
    *arguments = perm;      // set dimensions to be permuted
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    nd4j::ops::permute permute;
 | 
					 | 
				
			||||||
    Nd4jStatus status = permute.execute(block);
 | 
					 | 
				
			||||||
    ASSERT_EQ(ND4J_STATUS_OK, status);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ASSERT_TRUE(x->isSameShapeStrict(*exp));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    delete exp;
 | 
					 | 
				
			||||||
    delete block;
 | 
					 | 
				
			||||||
    delete variableSpace;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
//////////////////////////////////////////////////////////////////////
 | 
					//////////////////////////////////////////////////////////////////////
 | 
				
			||||||
// not-in-place
 | 
					// not-in-place
 | 
				
			||||||
TEST_F(DeclarableOpsTests1, Permute2) {
 | 
					TEST_F(DeclarableOpsTests1, Permute1) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Nd4jLong shapeX[]   = {3, 5, 10, 15, 150, 15, 1, 0, 1, 99};
 | 
					    Nd4jLong shapeX[]   = {3, 5,10,15,  150,15,1,  0,1,99};
 | 
				
			||||||
    Nd4jLong shapeExp[] = {3, 15, 5, 10, 1, 150, 15, 0, 0, 99};
 | 
					    Nd4jLong shapeExp[] = {3, 15,5,10,  50,10,1,  0,1,99};
 | 
				
			||||||
    const std::vector<int> perm = {2, 0, 1};
 | 
					    const std::vector<int> perm = {2, 0, 1};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ArrayOptions::setDataType(shapeX, nd4j::DataType::FLOAT32);
 | 
					    ArrayOptions::setDataType(shapeX, nd4j::DataType::FLOAT32);
 | 
				
			||||||
 | 
				
			|||||||
@ -161,23 +161,6 @@ TEST_F(EmptyTests, Test_Reshape_1) {
 | 
				
			|||||||
    delete result;
 | 
					    delete result;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST_F(EmptyTests, Test_Reshape_2) {
 | 
					 | 
				
			||||||
    auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
 | 
					 | 
				
			||||||
    auto exp = NDArrayFactory::create<float>(119.0f);
 | 
					 | 
				
			||||||
    auto empty = NDArrayFactory::empty_<Nd4jLong>();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    nd4j::ops::reshape op;
 | 
					 | 
				
			||||||
    auto result = op.evaluate({&vector, empty}, {}, {}, {}, {}, true);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ASSERT_EQ(Status::OK(), result->status());
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ASSERT_EQ(exp, *result->at(0));
 | 
					 | 
				
			||||||
    ASSERT_EQ(exp, vector);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    delete empty;
 | 
					 | 
				
			||||||
    delete result;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
TEST_F(EmptyTests, Test_Reshape_3) {
 | 
					TEST_F(EmptyTests, Test_Reshape_3) {
 | 
				
			||||||
    auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
 | 
					    auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
 | 
				
			||||||
    auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
 | 
					    auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
 | 
				
			||||||
 | 
				
			|||||||
@ -65,7 +65,7 @@ TEST_F(PlaygroundTests, test_avx) {
 | 
				
			|||||||
    nd4j_printf("Optimal level: %i; Binary level: %i;\n", ::optimalLevel(), ::binaryLevel());
 | 
					    nd4j_printf("Optimal level: %i; Binary level: %i;\n", ::optimalLevel(), ::binaryLevel());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/*
 | 
				
			||||||
TEST_F(PlaygroundTests, test_bert_1) {
 | 
					TEST_F(PlaygroundTests, test_bert_1) {
 | 
				
			||||||
    // this test will run ONLY if this model exists
 | 
					    // this test will run ONLY if this model exists
 | 
				
			||||||
    if (nd4j::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0)
 | 
					    if (nd4j::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0)
 | 
				
			||||||
@ -86,15 +86,15 @@ TEST_F(PlaygroundTests, test_bert_1) {
 | 
				
			|||||||
    graph->getVariableSpace()->putVariable(86,0, u);
 | 
					    graph->getVariableSpace()->putVariable(86,0, u);
 | 
				
			||||||
    graph->getVariableSpace()->putVariable(87,0, v);
 | 
					    graph->getVariableSpace()->putVariable(87,0, v);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/*
 | 
					 | 
				
			||||||
    // validating graph now
 | 
					 | 
				
			||||||
    auto status = GraphExecutioner::execute(graph);
 | 
					 | 
				
			||||||
    ASSERT_EQ(Status::OK(), status);
 | 
					 | 
				
			||||||
    ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198));
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto array = graph->getVariableSpace()->getVariable(198)->getNDArray();
 | 
					    // validating graph now
 | 
				
			||||||
    ASSERT_EQ(z, *array);
 | 
					    // auto status = GraphExecutioner::execute(graph);
 | 
				
			||||||
*/
 | 
					    // ASSERT_EQ(Status::OK(), status);
 | 
				
			||||||
 | 
					    // ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // auto array = graph->getVariableSpace()->getVariable(198)->getNDArray();
 | 
				
			||||||
 | 
					    // ASSERT_EQ(z, *array);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    nd4j::Environment::getInstance()->setProfiling(true);
 | 
					    nd4j::Environment::getInstance()->setProfiling(true);
 | 
				
			||||||
    auto profile = GraphProfilingHelper::profile(graph, 1);
 | 
					    auto profile = GraphProfilingHelper::profile(graph, 1);
 | 
				
			||||||
@ -104,28 +104,27 @@ TEST_F(PlaygroundTests, test_bert_1) {
 | 
				
			|||||||
    nd4j::Environment::getInstance()->setProfiling(false);
 | 
					    nd4j::Environment::getInstance()->setProfiling(false);
 | 
				
			||||||
    delete profile;
 | 
					    delete profile;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/*
 | 
					 | 
				
			||||||
    std::vector<Nd4jLong> values;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (int e = 0; e < 1; e++) {
 | 
					    // std::vector<Nd4jLong> values;
 | 
				
			||||||
        auto timeStart = std::chrono::system_clock::now();
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        GraphExecutioner::execute(graph);
 | 
					    // for (int e = 0; e < 1; e++) {
 | 
				
			||||||
 | 
					    //     auto timeStart = std::chrono::system_clock::now();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        auto timeEnd = std::chrono::system_clock::now();
 | 
					    //     GraphExecutioner::execute(graph);
 | 
				
			||||||
        auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
 | 
					 | 
				
			||||||
        values.emplace_back(outerTime);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    std::sort(values.begin(), values.end());
 | 
					    //     auto timeEnd = std::chrono::system_clock::now();
 | 
				
			||||||
 | 
					    //     auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
 | 
				
			||||||
 | 
					    //     values.emplace_back(outerTime);
 | 
				
			||||||
 | 
					    // }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
 | 
					    // std::sort(values.begin(), values.end());
 | 
				
			||||||
*/
 | 
					
 | 
				
			||||||
 | 
					    // nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    delete graph;
 | 
					    delete graph;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/*
 | 
					
 | 
				
			||||||
TEST_F(PlaygroundTests, test_broadcast_1) {
 | 
					TEST_F(PlaygroundTests, test_broadcast_1) {
 | 
				
			||||||
    int pool = 10;
 | 
					    int pool = 10;
 | 
				
			||||||
    std::vector<NDArray*> aX(pool);
 | 
					    std::vector<NDArray*> aX(pool);
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user