Update parameterized python tests, move python tests to proper package
parent
13cae7fb60
commit
224f18a586
|
@ -1,31 +1,31 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
* * ******************************************************************************
|
||||
* * *
|
||||
* * *
|
||||
* * * This program and the accompanying materials are made available under the
|
||||
* * * terms of the Apache License, Version 2.0 which is available at
|
||||
* * * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* * *
|
||||
* * * See the NOTICE file distributed with this work for additional
|
||||
* * * information regarding copyright ownership.
|
||||
* * * Unless required by applicable law or agreed to in writing, software
|
||||
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * * License for the specific language governing permissions and limitations
|
||||
* * * under the License.
|
||||
* * *
|
||||
* * * SPDX-License-Identifier: Apache-2.0
|
||||
* * *****************************************************************************
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
package org.nd4j.python4j;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.python4j.*;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -35,7 +35,6 @@ import javax.annotation.concurrent.NotThreadSafe;
|
|||
import java.lang.reflect.Method;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
@ -43,6 +42,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
|
||||
@NotThreadSafe
|
||||
public class PythonNumpyBasicTest {
|
||||
|
||||
public static Stream<Arguments> params() {
|
||||
DataType[] types = new DataType[] {
|
||||
DataType.BOOL,
|
||||
|
@ -61,9 +61,9 @@ public class PythonNumpyBasicTest {
|
|||
};
|
||||
|
||||
long[][] shapes = new long[][]{
|
||||
new long[]{2, 3},
|
||||
new long[]{3},
|
||||
new long[]{1},
|
||||
new long[]{2, 3},
|
||||
new long[]{3},
|
||||
new long[]{1},
|
||||
new long[]{} // scalar
|
||||
};
|
||||
|
||||
|
@ -78,23 +78,23 @@ public class PythonNumpyBasicTest {
|
|||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testConversion(DataType dataType,long[] shape){
|
||||
try(PythonGIL pythonGIL = PythonGIL.lock()) {
|
||||
INDArray arr = Nd4j.zeros(dataType, shape);
|
||||
PythonObject npArr = PythonTypes.convert(arr);
|
||||
INDArray arr2 = PythonTypes.<INDArray>getPythonTypeForPythonObject(npArr).toJava(npArr);
|
||||
if (dataType == DataType.BFLOAT16){
|
||||
arr = arr.castTo(DataType.FLOAT);
|
||||
}
|
||||
assertEquals(arr,arr2);
|
||||
}
|
||||
@MethodSource("org.nd4j.python4j.PythonNumpyBasicTest#params")
|
||||
public void testConversion(DataType dataType,long[] shape) {
|
||||
try(PythonGIL pythonGIL = PythonGIL.lock()) {
|
||||
INDArray arr = Nd4j.zeros(dataType, shape);
|
||||
PythonObject npArr = PythonTypes.convert(arr);
|
||||
INDArray arr2 = PythonTypes.<INDArray>getPythonTypeForPythonObject(npArr).toJava(npArr);
|
||||
if (dataType == DataType.BFLOAT16){
|
||||
arr = arr.castTo(DataType.FLOAT);
|
||||
}
|
||||
assertEquals(arr,arr2);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
@MethodSource("org.nd4j.python4j.PythonNumpyBasicTest#params")
|
||||
public void testExecution(DataType dataType,long[] shape) {
|
||||
try(PythonGIL pythonGIL = PythonGIL.lock()) {
|
||||
List<PythonVariable> inputs = new ArrayList<>();
|
||||
|
@ -124,7 +124,7 @@ public class PythonNumpyBasicTest {
|
|||
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
@MethodSource("org.nd4j.python4j.PythonNumpyBasicTest#params")
|
||||
public void testInplaceExecution(DataType dataType,long[] shape) {
|
||||
try(PythonGIL pythonGIL = PythonGIL.lock()) {
|
||||
if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16)return;
|
|
@ -1,4 +1,27 @@
|
|||
/*
|
||||
*
|
||||
* * ******************************************************************************
|
||||
* * *
|
||||
* * *
|
||||
* * * This program and the accompanying materials are made available under the
|
||||
* * * terms of the Apache License, Version 2.0 which is available at
|
||||
* * * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* * *
|
||||
* * * See the NOTICE file distributed with this work for additional
|
||||
* * * information regarding copyright ownership.
|
||||
* * * Unless required by applicable law or agreed to in writing, software
|
||||
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * * License for the specific language governing permissions and limitations
|
||||
* * * under the License.
|
||||
* * *
|
||||
* * * SPDX-License-Identifier: Apache-2.0
|
||||
* * *****************************************************************************
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
package org.nd4j.python4j;/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
|
@ -61,8 +84,7 @@ public class PythonNumpyCollectionsTest {
|
|||
}).stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
@Test
|
||||
@MethodSource("#params")
|
||||
@MethodSource("org.nd4j.python4j.PythonNumpyCollectionsTest#params")
|
||||
@ParameterizedTest
|
||||
public void testPythonDictFromMap(DataType dataType) throws PythonException {
|
||||
try(PythonGIL pythonGIL = PythonGIL.lock()) {
|
||||
|
@ -84,8 +106,7 @@ public class PythonNumpyCollectionsTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@MethodSource("#params")
|
||||
@MethodSource("org.nd4j.python4j.PythonNumpyCollectionsTest#params")
|
||||
@ParameterizedTest
|
||||
public void testPythonListFromList(DataType dataType) throws PythonException {
|
||||
try(PythonGIL pythonGIL = PythonGIL.lock()) {
|
|
@ -1,4 +1,27 @@
|
|||
/*
|
||||
*
|
||||
* * ******************************************************************************
|
||||
* * *
|
||||
* * *
|
||||
* * * This program and the accompanying materials are made available under the
|
||||
* * * terms of the Apache License, Version 2.0 which is available at
|
||||
* * * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* * *
|
||||
* * * See the NOTICE file distributed with this work for additional
|
||||
* * * information regarding copyright ownership.
|
||||
* * * Unless required by applicable law or agreed to in writing, software
|
||||
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * * License for the specific language governing permissions and limitations
|
||||
* * * under the License.
|
||||
* * *
|
||||
* * * SPDX-License-Identifier: Apache-2.0
|
||||
* * *****************************************************************************
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
package org.nd4j.python4j;/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
|
@ -18,11 +41,6 @@
|
|||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
import org.nd4j.python4j.Python;
|
||||
import org.nd4j.python4j.PythonGC;
|
||||
import org.nd4j.python4j.PythonGIL;
|
||||
import org.nd4j.python4j.PythonObject;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
|
@ -1,4 +1,27 @@
|
|||
/*
|
||||
*
|
||||
* * ******************************************************************************
|
||||
* * *
|
||||
* * *
|
||||
* * * This program and the accompanying materials are made available under the
|
||||
* * * terms of the Apache License, Version 2.0 which is available at
|
||||
* * * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* * *
|
||||
* * * See the NOTICE file distributed with this work for additional
|
||||
* * * information regarding copyright ownership.
|
||||
* * * Unless required by applicable law or agreed to in writing, software
|
||||
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * * License for the specific language governing permissions and limitations
|
||||
* * * under the License.
|
||||
* * *
|
||||
* * * SPDX-License-Identifier: Apache-2.0
|
||||
* * *****************************************************************************
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
package org.nd4j.python4j;/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
|
@ -18,8 +41,6 @@
|
|||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
import org.nd4j.python4j.*;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
@ -1,4 +1,27 @@
|
|||
/*
|
||||
*
|
||||
* * ******************************************************************************
|
||||
* * *
|
||||
* * *
|
||||
* * * This program and the accompanying materials are made available under the
|
||||
* * * terms of the Apache License, Version 2.0 which is available at
|
||||
* * * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* * *
|
||||
* * * See the NOTICE file distributed with this work for additional
|
||||
* * * information regarding copyright ownership.
|
||||
* * * Unless required by applicable law or agreed to in writing, software
|
||||
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * * License for the specific language governing permissions and limitations
|
||||
* * * under the License.
|
||||
* * *
|
||||
* * * SPDX-License-Identifier: Apache-2.0
|
||||
* * *****************************************************************************
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
package org.nd4j.python4j;/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
|
@ -61,8 +84,7 @@ public class PythonNumpyMultiThreadTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@MethodSource("#params")
|
||||
@MethodSource("org.nd4j.python4j.PythonNumpyMultiThreadTest#params")
|
||||
@ParameterizedTest
|
||||
public void testMultiThreading1(DataType dataType) throws Throwable {
|
||||
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
|
||||
|
@ -100,8 +122,7 @@ public class PythonNumpyMultiThreadTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@MethodSource("#params")
|
||||
@MethodSource("org.nd4j.python4j.PythonNumpyMultiThreadTest#params")
|
||||
@ParameterizedTest
|
||||
public void testMultiThreading2(DataType dataType) throws Throwable {
|
||||
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<>());
|
|
@ -1,4 +1,27 @@
|
|||
/*
|
||||
*
|
||||
* * ******************************************************************************
|
||||
* * *
|
||||
* * *
|
||||
* * * This program and the accompanying materials are made available under the
|
||||
* * * terms of the Apache License, Version 2.0 which is available at
|
||||
* * * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* * *
|
||||
* * * See the NOTICE file distributed with this work for additional
|
||||
* * * information regarding copyright ownership.
|
||||
* * * Unless required by applicable law or agreed to in writing, software
|
||||
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * * License for the specific language governing permissions and limitations
|
||||
* * * under the License.
|
||||
* * *
|
||||
* * * SPDX-License-Identifier: Apache-2.0
|
||||
* * *****************************************************************************
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
package org.nd4j.python4j;/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
Loading…
Reference in New Issue