120 lines
4.0 KiB
Java
120 lines
4.0 KiB
Java
/*
|
|
* ******************************************************************************
|
|
* *
|
|
* *
|
|
* * This program and the accompanying materials are made available under the
|
|
* * terms of the Apache License, Version 2.0 which is available at
|
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
* *
|
|
* * See the NOTICE file distributed with this work for additional
|
|
* * information regarding copyright ownership.
|
|
* * Unless required by applicable law or agreed to in writing, software
|
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* * License for the specific language governing permissions and limitations
|
|
* * under the License.
|
|
* *
|
|
* * SPDX-License-Identifier: Apache-2.0
|
|
* *****************************************************************************
|
|
*/
|
|
|
|
package org.nd4j.linalg.schedule;
|
|
|
|
import lombok.Data;
|
|
import lombok.EqualsAndHashCode;
|
|
import lombok.NonNull;
|
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
|
|
import java.util.Arrays;
|
|
import java.util.HashMap;
|
|
import java.util.Map;
|
|
|
|
@Data
|
|
@EqualsAndHashCode
|
|
@JsonIgnoreProperties({"allKeysSorted"})
|
|
public class MapSchedule implements ISchedule {
|
|
|
|
private ScheduleType scheduleType;
|
|
private Map<Integer, Double> values;
|
|
|
|
private int[] allKeysSorted;
|
|
|
|
public MapSchedule(@JsonProperty("scheduleType") @NonNull ScheduleType scheduleType,
|
|
@JsonProperty("values") @NonNull Map<Integer, Double> values) {
|
|
if (!values.containsKey(0)) {
|
|
throw new IllegalArgumentException("Invalid set of values: must contain initial value (position 0)");
|
|
}
|
|
this.scheduleType = scheduleType;
|
|
this.values = values;
|
|
|
|
this.allKeysSorted = new int[values.size()];
|
|
int pos = 0;
|
|
for (Integer i : values.keySet()) {
|
|
allKeysSorted[pos++] = i;
|
|
}
|
|
Arrays.sort(allKeysSorted);
|
|
}
|
|
|
|
@Override
|
|
public double valueAt(int iteration, int epoch) {
|
|
int i = (scheduleType == ScheduleType.ITERATION ? iteration : epoch);
|
|
|
|
if (values.containsKey(i)) {
|
|
return values.get(i);
|
|
} else {
|
|
//Key doesn't exist - find nearest key...
|
|
if (i >= allKeysSorted[allKeysSorted.length - 1]) {
|
|
return values.get(allKeysSorted[allKeysSorted.length - 1]);
|
|
} else {
|
|
/*
|
|
Returned:
|
|
index of the search key, if it is contained in the array; otherwise, (-(insertion point) - 1). The
|
|
insertion point is defined as the point at which the key would be inserted into the array: the index
|
|
of the first element greater than the key
|
|
*/
|
|
int pt = Arrays.binarySearch(allKeysSorted, i);
|
|
int iPt = -(pt + 1);
|
|
double d = values.get(allKeysSorted[iPt-1]);
|
|
return d;
|
|
}
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public ISchedule clone() {
|
|
return new MapSchedule(scheduleType, values);
|
|
}
|
|
|
|
/**
|
|
* DynamicCustomOpsBuilder for conveniently constructing map schedules
|
|
*/
|
|
public static class Builder {
|
|
|
|
private final ScheduleType scheduleType;
|
|
private final Map<Integer, Double> values = new HashMap<>();
|
|
|
|
/**
|
|
* @param scheduleType Schedule opType to use
|
|
*/
|
|
public Builder(ScheduleType scheduleType) {
|
|
this.scheduleType = scheduleType;
|
|
}
|
|
|
|
/**
|
|
* Add a single point to the map schedule. Indexes start at 0
|
|
*
|
|
* @param position Position to add (iteration or epoch index, depending on setting)
|
|
* @param value Value for that iteraiton/epoch
|
|
*/
|
|
public Builder add(int position, double value) {
|
|
values.put(position, value);
|
|
return this;
|
|
}
|
|
|
|
public MapSchedule build() {
|
|
return new MapSchedule(scheduleType, values);
|
|
}
|
|
}
|
|
}
|