parent
e51e6ebfd2
commit
e7730eded4
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -17,7 +17,6 @@
|
|||
package org.nd4j.list;
|
||||
|
||||
import lombok.val;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -35,35 +34,15 @@ import java.util.*;
|
|||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@SuppressWarnings("unchecked") //too many of them.
|
||||
public abstract class BaseNDArrayList<X extends Number> extends AbstractList<X> {
|
||||
protected INDArray container;
|
||||
protected int size;
|
||||
|
||||
|
||||
|
||||
public BaseNDArrayList() {
|
||||
BaseNDArrayList() {
|
||||
this.container = Nd4j.create(10);
|
||||
}
|
||||
|
||||
/**
|
||||
* Specify the underlying ndarray for this list.
|
||||
* @param container the underlying array.
|
||||
*/
|
||||
public BaseNDArrayList(INDArray container) {
|
||||
this.container = container;
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocates the container and this list with
|
||||
* the given size
|
||||
* @param size the size to allocate with
|
||||
*/
|
||||
public void allocateWithSize(int size) {
|
||||
container = Nd4j.create(1,size);
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get a view of the underlying array
|
||||
* relative to the size of the actual array.
|
||||
|
@ -321,11 +300,11 @@ public abstract class BaseNDArrayList<X extends Number> extends AbstractList<X
|
|||
private class NDArrayListIterator implements ListIterator<X> {
|
||||
private int curr = 0;
|
||||
|
||||
public NDArrayListIterator(int curr) {
|
||||
NDArrayListIterator(int curr) {
|
||||
this.curr = curr;
|
||||
}
|
||||
|
||||
public NDArrayListIterator() {
|
||||
NDArrayListIterator() {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -335,9 +314,9 @@ public abstract class BaseNDArrayList<X extends Number> extends AbstractList<X
|
|||
|
||||
@Override
|
||||
public X next() {
|
||||
Number ret = get(curr);
|
||||
X ret = get(curr);
|
||||
curr++;
|
||||
return (X) ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -347,9 +326,9 @@ public abstract class BaseNDArrayList<X extends Number> extends AbstractList<X
|
|||
|
||||
@Override
|
||||
public X previous() {
|
||||
Number ret = get(curr - 1);
|
||||
X ret = get(curr - 1);
|
||||
curr--;
|
||||
return (X) ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.list;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
/**
|
||||
* An {@link BaseNDArrayList} for float
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class FloatNDArrayList extends BaseNDArrayList<Float> {
|
||||
public FloatNDArrayList() {
|
||||
}
|
||||
|
||||
public FloatNDArrayList(INDArray container) {
|
||||
super(container);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Float get(int i) {
|
||||
Number ret = container.getDouble(i);
|
||||
return ret.floatValue();
|
||||
}
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.list;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
/**
|
||||
* An {@link BaseNDArrayList} for integers
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class IntNDArrayList extends BaseNDArrayList<Integer> {
|
||||
public IntNDArrayList() {
|
||||
}
|
||||
|
||||
public IntNDArrayList(INDArray container) {
|
||||
super(container);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Integer get(int i) {
|
||||
Number ret = container.getDouble(i);
|
||||
return ret.intValue();
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -17,7 +17,6 @@
|
|||
package org.nd4j.list;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
|
@ -272,11 +271,11 @@ public class NDArrayList extends BaseNDArrayList<Double> {
|
|||
private class NDArrayListIterator implements ListIterator<Double> {
|
||||
private int curr = 0;
|
||||
|
||||
public NDArrayListIterator(int curr) {
|
||||
NDArrayListIterator(int curr) {
|
||||
this.curr = curr;
|
||||
}
|
||||
|
||||
public NDArrayListIterator() {
|
||||
NDArrayListIterator() {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.list.matrix;
|
||||
|
||||
import org.nd4j.list.FloatNDArrayList;
|
||||
|
||||
/**
|
||||
* A {@link MatrixBaseNDArrayList}
|
||||
* for float data type
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class FloatMatrixNDArrayList extends MatrixBaseNDArrayList<FloatNDArrayList> {
|
||||
public FloatMatrixNDArrayList() {
|
||||
}
|
||||
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.list.matrix;
|
||||
|
||||
import org.nd4j.list.IntNDArrayList;
|
||||
|
||||
/**
|
||||
* A {@link MatrixBaseNDArrayList}
|
||||
* for int data type
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class IntMatrixNDArrayList extends MatrixBaseNDArrayList<IntNDArrayList> {
|
||||
public IntMatrixNDArrayList() {
|
||||
}
|
||||
|
||||
}
|
|
@ -1,184 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.list.matrix;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.list.BaseNDArrayList;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* An {@link ArrayList} like implementation of {@link List}
|
||||
* using {@link INDArray} as the backing data structure.
|
||||
*
|
||||
* Creates an internal container of ndarray lists with a matrix shape.
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public abstract class MatrixBaseNDArrayList<X extends BaseNDArrayList> extends AbstractList<X> {
|
||||
private List<X> list = new ArrayList<>();
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Get a view of the underlying array
|
||||
* relative to the size of the actual array.
|
||||
* (Sometimes there are overflows in the internals
|
||||
* but you want to use the internal INDArray for computing something
|
||||
* directly, this gives you the relevant subset that reflects the content of the list)
|
||||
* @return the view of the underlying ndarray relative to the collection's real size
|
||||
*/
|
||||
public INDArray array() {
|
||||
List<INDArray> retList = new ArrayList<>(list.size());
|
||||
for(X x : list) {
|
||||
INDArray arr = x.array();
|
||||
retList.add(arr.reshape(1, arr.length()));
|
||||
}
|
||||
|
||||
return Nd4j.concat(0,retList.toArray(new INDArray[retList.size()]));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return list.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEmpty() {
|
||||
return list.isEmpty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean contains(Object o) {
|
||||
return list.contains(o);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<X> iterator() {
|
||||
return list.iterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object[] toArray() {
|
||||
return list.toArray();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T[] toArray(T[] ts) {
|
||||
return list.toArray(ts);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean add(X aX) {
|
||||
return list.add(aX);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean remove(Object o) {
|
||||
return list.remove(o);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean containsAll(Collection<?> collection) {
|
||||
return list.containsAll(collection);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean addAll(Collection<? extends X> collection) {
|
||||
return list.addAll(collection);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean addAll(int i, Collection<? extends X> collection) {
|
||||
return list.addAll(collection);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean removeAll(Collection<?> collection) {
|
||||
return list.removeAll(collection);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean retainAll(Collection<?> collection) {
|
||||
return list.retainAll(collection);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clear() {
|
||||
list.clear();
|
||||
}
|
||||
|
||||
@Override
|
||||
public X get(int i) {
|
||||
return list.get(i);
|
||||
}
|
||||
|
||||
@Override
|
||||
public X set(int i, X aX) {
|
||||
return list.set(i,aX);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(int i, X aX) {
|
||||
list.add(i,aX);
|
||||
}
|
||||
|
||||
@Override
|
||||
public X remove(int i) {
|
||||
return list.remove(i);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int indexOf(Object o) {
|
||||
return list.indexOf(o);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int lastIndexOf(Object o) {
|
||||
return list.lastIndexOf(o);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListIterator<X> listIterator() {
|
||||
return list.listIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListIterator<X> listIterator(int i) {
|
||||
return list.listIterator(i);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return list.toString();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get entry i,j in the matrix
|
||||
* @param i the row
|
||||
* @param j the column
|
||||
* @return the entry at i,j if it exists
|
||||
*/
|
||||
public Number getEntry(int i,int j) {
|
||||
return list.get(i).get(j);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,32 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.list.matrix;
|
||||
|
||||
import org.nd4j.list.NDArrayList;
|
||||
|
||||
/**
|
||||
* A {@link MatrixBaseNDArrayList}
|
||||
* for double data type
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class MatrixNDArrayList extends MatrixBaseNDArrayList<NDArrayList> {
|
||||
public MatrixNDArrayList() {
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -18,15 +18,12 @@ package org.nd4j.list;
|
|||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.list.matrix.MatrixNDArrayList;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
|
||||
public class NDArrayListTest extends BaseNd4jTest {
|
||||
|
||||
|
@ -71,24 +68,4 @@ public class NDArrayListTest extends BaseNd4jTest {
|
|||
assertEquals(ndArrayList.size(),ndArrayList.array().length());
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testMatrixList() {
|
||||
MatrixNDArrayList matrixNDArrayList = new MatrixNDArrayList();
|
||||
for(int i = 0; i < 5; i++) {
|
||||
NDArrayList ndArrayList = new NDArrayList();
|
||||
for(int j = 0; j < 4; j++) {
|
||||
ndArrayList.add((double) j);
|
||||
}
|
||||
|
||||
matrixNDArrayList.add(ndArrayList);
|
||||
}
|
||||
|
||||
INDArray arr = matrixNDArrayList.array();
|
||||
assertEquals(5,arr.rows());
|
||||
assertFalse(matrixNDArrayList.isEmpty());
|
||||
assertEquals(0.0,matrixNDArrayList.getEntry(0,0));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue