/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // This class describes collection of NDArrays // // @author raver119!gmail.com // #ifndef NDARRAY_LIST_H #define NDARRAY_LIST_H #include #include #include #include #include #include namespace nd4j { class ND4J_EXPORT NDArrayList { private: // workspace where chunks belong to //nd4j::memory::Workspace* _workspace = nullptr; nd4j::LaunchContext * _context = nd4j::LaunchContext ::defaultContext(); // numeric and symbolic ids of this list std::pair _id; std::string _name; nd4j::DataType _dtype; // stored chunks std::map _chunks; // just a counter, for stored elements std::atomic _elements; std::atomic _counter; // reference shape std::vector _shape; // unstack axis int _axis = 0; // bool _expandable = false; // maximum number of elements int _height = 0; public: NDArrayList(int height, bool expandable = false); ~NDArrayList(); nd4j::DataType dataType(); NDArray* read(int idx); NDArray* readRaw(int idx); Nd4jStatus write(int idx, NDArray* array); NDArray* pick(std::initializer_list indices); NDArray* pick(std::vector& indices); bool isWritten(int index); std::vector& shape(); NDArray* stack(); void unstack(NDArray* array, int axis); std::pair& id(); std::string& name(); //nd4j::memory::Workspace* workspace(); nd4j::LaunchContext * context(); NDArrayList* clone(); bool equals(NDArrayList& other); int elements(); int height(); int counter(); }; } #endif