cavis/libnd4j/include/helpers/ShapeUtils.h

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#ifndef LIBND4J_SHAPEUTILS_H
#define LIBND4J_SHAPEUTILS_H

#include <vector>
#include <NDArray.h>

namespace nd4j {

class ND4J_EXPORT ShapeUtils {
public:
// evaluate shape for array resulting from tensorDot operation; also evaluate shapes and permutation dimensions for transposing the two input arrays
static std::vector<Nd4jLong> evalShapeForTensorDot(const Nd4jLong* aShapeInfo, const Nd4jLong* bShapeInfo, std::vector<int> axesA, std::vector<int> axesB, std::vector<int>& permutAt, std::vector<int>& permutBt, std::vector<Nd4jLong>& shapeAt, std::vector<Nd4jLong>& shapeBt);
static std::vector<Nd4jLong> evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector<int>& axesA, const std::vector<int>& axesB, std::vector<int>& permutAt, std::vector<int>& permutBt, std::vector<Nd4jLong>& shapeAt, std::vector<Nd4jLong>& shapeBt);
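/**
* usage sketch (illustrative only; a and b are hypothetical NDArray pointers):
*   // a has shape [2,3,4], b has shape [4,5]; contract axis 2 of a with axis 0 of b
*   std::vector<int> permutAt, permutBt;
*   std::vector<Nd4jLong> shapeAt, shapeBt;
*   auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, {2}, {0}, permutAt, permutBt, shapeAt, shapeBt);
*   // expected outShape = {2,3,5}; shapeAt/shapeBt receive the 2D shapes used for the underlying matmul
*/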
// evaluate resulting shape after reduce operation
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const nd4j::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
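/**
* usage sketch (illustrative only; arr is a hypothetical NDArray with shape [2,3,4]):
*   std::vector<int> dims = {1};
*   auto reducedShapeInfo  = ShapeUtils::evalReduceShapeInfo('c', dims, arr);          // -> shape [2,4]
*   auto keptDimsShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dims, arr, true);    // keepDims -> shape [2,1,4]
*/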
/**
* evaluate output shape for reduce operation when input shape is empty
* behavior is analogous to TensorFlow
*/
static Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace);
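/**
* illustrative expectation (assuming the TensorFlow-like behavior mentioned above):
*   input shape [2,0,3], dimensions = {1}, keepDims = false  ->  output shape [2,3]
*   input shape [2,0,3], dimensions = {1}, keepDims = true   ->  output shape [2,1,3]
*/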
// evaluate shape for array which is result of repeat operation applied to arr
static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
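// illustrative expectation (assuming numpy.repeat-like semantics): arr shape [2,3], axis = 1, repeats = {2,2,2}  ->  result shape [2,6]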
// evaluate shapeInfo of permuted array
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
// evaluate shapeInfo of transposed array
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace);
static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);
// return new (shorter) sorted dimensions array without dimensions that are present in input vector
static std::vector<int> evalDimsToExclude(const int rank, const int dimsLen, const int* dimensions);
static std::vector<int> evalDimsToExclude(const int rank, const std::vector<int>& dimensions);
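// usage sketch (illustrative only): ShapeUtils::evalDimsToExclude(4, {0,2})  ->  {1,3}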
// check whether 2 arrays have mutually broadcastable shapes
// shape comparison starts from the end
static bool areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2);
static bool areShapesBroadcastable(Nd4jLong* shapeX, Nd4jLong* shapeY);
static bool areShapesBroadcastable(const std::vector<Nd4jLong>& shape1, const std::vector<Nd4jLong>& shape2);
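/**
* usage sketch (illustrative only):
*   ShapeUtils::areShapesBroadcastable(std::vector<Nd4jLong>{2,3,4}, std::vector<Nd4jLong>{3,1});   // true  (trailing dims: 4 vs 1, 3 vs 3)
*   ShapeUtils::areShapesBroadcastable(std::vector<Nd4jLong>{2,3},   std::vector<Nd4jLong>{4,3});   // false (2 vs 4, neither equals 1)
*/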
// check the possibility of broadcast operation, if true then return shapeInfo of resulting array
// if evalMinMax == false then array with larger rank has to be passed as first argument
static bool evalBroadcastShapeInfo(const NDArray& max, const NDArray& min, const bool evalMinMax, Nd4jLong*& resultShapeInfo, nd4j::memory::Workspace* workspace);
static bool evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool evalMinMax, Nd4jLong*& resultShapeInfo, nd4j::memory::Workspace* workspace);
// evaluate sorted vector of max axes to create tads along in case of simple broadcast operation
// if simple broadcast is not possible then empty vector is returned
// PLEASE NOTE: condition (rank_max >= rank_min) should be satisfied !
static std::vector<int> tadAxesForSimpleBroadcast(const NDArray& max, const NDArray& min);
// check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo
static bool evalCommonBroadcastShapeInfo(const std::vector<const NDArray*>& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace = nullptr);
// return sorted vector of dimensions common (the same) for the two arrays; the dimension values correspond to the array with the bigger rank
// for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3}
static std::vector<int> getDimsWithSameShape(const NDArray& max, const NDArray& min);
// evaluate shapeInfo for resulting array of tile operation
static Nd4jLong* evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, nd4j::memory::Workspace* workspace);
// returns shape part of shapeInfo as std::vector
static std::vector<Nd4jLong> pullShapeFromShapeInfo(Nd4jLong *shapeInfo);
static std::string shapeAsString(const NDArray* array);
static std::string shapeAsString(const std::vector<Nd4jLong>& shape);
static std::string shapeAsString(const Nd4jLong* shapeInfo);
static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo);
static std::string strideAsString(const NDArray* array);
static std::vector<Nd4jLong> shapeAsVector(const Nd4jLong* shapeInfo);
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
static Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, nd4j::memory::Workspace* workspace);
static std::vector<int> evalBroadcastBackwardAxis(const Nd4jLong *operand, const Nd4jLong *result);
// utility to calculate matrix product shape with given source shapes and additional params
// returns pointer to the resulting shapeInfo
static Nd4jLong* matrixProductShape(Nd4jLong* theFirstShape, Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, nd4j::DataType dtype, nd4j::memory::Workspace* workspace);
/**
* This method evaluates the permutation vector necessary for permuting shapeFrom into shapeTo
* if shapeFrom is identical to shapeTo (permutation is unnecessary) then an empty vector is returned
* if such a permutation is impossible an exception is thrown
*/
static std::vector<int> evalPermutFromTo(const std::vector<Nd4jLong>& shapeFrom, const std::vector<Nd4jLong>& shapeTo);
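// usage sketch (illustrative only; assumes permute semantics result[i] = shapeFrom[permut[i]]):
//   ShapeUtils::evalPermutFromTo({2,3,4}, {4,2,3})  ->  {2,0,1}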
/**
* This method composes shape (shape only, not whole shapeInfo!) using dimensions values and corresponding indexes,
* please note: the size of input vector dimsAndIdx must always be even, since the numbers of dimensions and indexes are the same,
* for example if dimsAndIdx = {dimC,dimB,dimA, 2,1,0} then output vector = {dimA,dimB,dimC}
*/
static std::vector<Nd4jLong> composeShapeUsingDimsAndIdx(const std::vector<int>& dimsAndIdx);
/**
* x * y = c, evaluate shape for array resulting from mmul operation
* possible cases: dot product (xRank=yRank=1), matrix-vector product (xRank=2, yRank=1), vector-matrix product (xRank=1, yRank=2), matrix-matrix product (xRank=yRank and rank >=2)
*/
static std::vector<Nd4jLong> evalShapeForMatmul(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const bool transX, const bool transY);
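/**
* illustrative expectations (shapes only, shapeInfo arguments omitted):
*   x: [2,3], y: [3,4], transX = false, transY = false  ->  [2,4]
*   x: [3,2], y: [3,4], transX = true,  transY = false  ->  [2,4]
*   x: [2,3], y: [3],   matrix-vector product           ->  [2]
*/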
/**
* evaluate number of sub-arrays along dimensions stored in dimsToExclude
* i.e. if shape is [2,3,4,5] and dimsToExclude={0,2}, then number of sub-arrays = 8
*/
static Nd4jLong getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vector<int>& dimsToExclude);
/**
* evaluate indexes ranges that define sub-array of array having shape=shapeInfo
* subArrIdx - index of current sub-array
* shapeInfo - shapeInfo of array for which to evaluate sub-arrays
* dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-arrays along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5],
* if dimsToExclude is empty then idxRanges containing all zeros (meaning the whole array) will be returned.
* idxRanges - where to put result, the length of idxRanges must be equal to 2*shapeInfo[0]
*/
static void evalIdxRangesForSubArr(const Nd4jLong subArrIdx, const Nd4jLong* shapeInfo, const std::vector<int>& dimsToExclude, Nd4jLong* idxRanges);
/**
* return shape without unities, for example if shape is [1,2,1,3] then [2,3] will be returned
* if unities are not present in given shapeInfo then exactly identical shape will be returned, for example [2,3] -> [2,3]
* edge case: if the given shape is [1,1,1,...,1] (all dims are unities) then the output will be empty, which denotes a scalar
*/
static std::vector<Nd4jLong> evalDimsWithoutUnities(const Nd4jLong* shapeInfo);
/**
* method returns false if permut == {0,1,2,...permut.size()-1} - in that case permutation is unnecessary
*/
FORCEINLINE static bool isPermutNecessary(const std::vector<int>& permut);
/**
* calculates strides using "dest" shape and given "order", also copies data type from "source" to "dest"
*/
static void updateStridesAndType(Nd4jLong* dest, const Nd4jLong* source, const char order);
/**
* calculates strides using "dest" shape and "order", also set "dtype" into "dest"
*/
static void updateStridesAndType(Nd4jLong* dest, const DataType dtype, const char order);
/**
* This method returns the number of bytes required for the string tensor header
* @param numStrings
* @return
*/
static FORCEINLINE Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings) {
    // we store +1 offset
    return (numStrings + 1) * sizeof(Nd4jLong);
}
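// usage sketch: stringBufferHeaderRequirements(3) -> (3 + 1) * sizeof(Nd4jLong), i.e. 32 bytes when Nd4jLong is 8 bytes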
/*
* check whether arr1/arr2 is a sub-array of arr2/arr1,
* this method does not evaluate which array is the sub-array; it returns true if arr1 is a sub-array of arr2 or arr2 is a sub-array of arr1
* sameDims is filled (and sorted) with the dimension values that match in both arr1 and arr2 shapes (unities are ignored)
* for example:
* if arr1{2,3} and arr2{2,4,3,7} then return true and sameDims contains {0,2}
* if arr1{1,1,3,1,3,1,1} and arr2{1,2,3,1,3} then return true and sameDims contains {2,4}
* if arr1{2,1,4,1,7,5} and arr2{1,1,4,5} then return true and sameDims contains {2,5}
static bool isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims);
*/
};
//////////////////////////////////////////////////////////////////////////
///// IMPLEMENTATION OF INLINE METHODS /////
//////////////////////////////////////////////////////////////////////////
FORCEINLINE bool ShapeUtils::isPermutNecessary(const std::vector<int>& permut) {

    for(int i=0; i<permut.size(); ++i)
        if(permut[i] != i)
            return true;

    return false;
}
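// usage sketch (illustrative only): ShapeUtils::isPermutNecessary({0,1,2}) -> false (identity), ShapeUtils::isPermutNecessary({1,0,2}) -> true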
}
#endif //LIBND4J_SHAPEUTILS_H