cavis/nd4s/src/main/scala/org/nd4s/ops/FilterOps.scala

66 lines
2.3 KiB
Scala
Raw Normal View History

2019-06-06 15:21:15 +03:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4s.ops
import org.nd4j.autodiff.samediff.SDVariable
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.BaseScalarOp
import org.nd4s.Implicits._
/** Factory for [[FilterOps]]. */
object FilterOps {

  /** Creates a filter op over `x`.
    *
    * @param x array whose elements are filtered
    * @param f predicate; elements for which it returns false are mapped to zero
    * @return a new [[FilterOps]] wrapping `x`
    */
  def apply(x: INDArray, f: Double => Boolean): FilterOps = {
    // `length()` is narrowed with `toInt` to fit the constructor's Int `len`.
    val elementCount = x.length().toInt
    new FilterOps(x, elementCount, f)
  }
}
/** Element-wise filter expressed as a scalar op. Each element for which the
  * predicate `f` holds is kept; every other element becomes zero.
  *
  * The input array is passed as both the `x` and `z` arrays of
  * [[BaseScalarOp]], with a null `y` array and a scalar value of 0.
  *
  * @param _x  array to filter
  * @param len element count supplied by the companion `apply`; not read by this class
  * @param f   predicate evaluated on each element after widening it to `Double`
  */
class FilterOps(_x: INDArray, len: Int, f: Double => Boolean)
    extends BaseScalarOp(_x, null: INDArray, _x, 0)
    with LeftAssociativeBinaryOp {

  /** No-arg constructor. It is presumably required for reflective instantiation
    * (TODO confirm). It uses a zero scalar and a null predicate, so `op` must
    * not be invoked on an instance built this way.
    */
  def this() = this(0.toScalar, 0, null)

  // NOTE(review): re-assigns the inherited `x` field, presumably to guard against
  // the superclass constructor not retaining it. Confirm against BaseScalarOp.
  x = _x

  /** -1 marks this op as having no native ND4J op number. The filtering is done
    * by the `op` overloads below, presumably via the nd4s executioner (TODO confirm).
    */
  override def opNum(): Int = -1

  override def opName(): String = "filter_scalar"

  override def onnxName(): String =
    throw new UnsupportedOperationException("FilterOps has no ONNX mapping")

  override def tensorflowName(): String =
    throw new UnsupportedOperationException("FilterOps has no TensorFlow mapping")

  override def doDiff(f1: java.util.List[SDVariable]): java.util.List[SDVariable] =
    throw new UnsupportedOperationException("FilterOps does not support differentiation")

  // Each overload keeps `origin` when the predicate accepts it, and substitutes zero otherwise.
  override def op(origin: Double): Double = if (f(origin)) origin else 0
  override def op(origin: Float): Float = if (f(origin)) origin else 0
  override def op(origin: Short): Short = if (f(origin)) origin else 0
  override def op(origin: Int): Int = if (f(origin)) origin else 0
  // Explicit conversion; magnitudes above 2^53 lose precision before reaching `f`.
  override def op(origin: Long): Long = if (f(origin.toDouble)) origin else 0

  /** The predicate is numeric, so it cannot be applied to String elements. */
  override def op(origin: String): String =
    throw new UnsupportedOperationException("FilterOps cannot be applied to String elements")
}