├── .gitignore ├── src ├── main │ └── java │ │ └── com │ │ └── hljunlp │ │ └── laozhongyi │ │ ├── process │ │ ├── ParamsAndCallable.java │ │ └── ShellProcess.java │ │ ├── strategy │ │ ├── Strategy.java │ │ ├── BaseStrategy.java │ │ ├── SimulatedAnnealingStrategy.java │ │ └── TraiditionalSimulatedAnnealingStrategy.java │ │ ├── HyperParameterScopeItem.java │ │ ├── Utils.java │ │ ├── HyperParamResultManager.java │ │ ├── HyperParameterScopeConfigReader.java │ │ ├── checkpoint │ │ ├── SimulatedAnnealingCheckPointManager.java │ │ ├── SimulatedAnnealingCheckPointData.java │ │ ├── CheckPointData.java │ │ └── CheckPointManager.java │ │ ├── HyperParameterConfig.java │ │ ├── GeneratedFileManager.java │ │ └── Laozhongyi.java └── test │ └── java │ └── com │ └── hljunlp │ └── laozhongyi │ └── strategy │ ├── StrategyTest.java │ └── checkpoint │ └── SimulatedAnnealingCheckPointManagerTest.java ├── LICENSE ├── README.md └── pom.xml /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | pom.xml.tag 3 | pom.xml.releaseBackup 4 | pom.xml.versionsBackup 5 | pom.xml.next 6 | release.properties 7 | dependency-reduced-pom.xml 8 | buildNumber.properties 9 | .mvn/timing.properties 10 | .classpath 11 | .project 12 | .settings/ 13 | 14 | # Avoid ignoring Maven wrapper jar file (.jar files are usually ignored) 15 | !/.mvn/wrapper/maven-wrapper.jar 16 | .idea 17 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/process/ParamsAndCallable.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.process; 2 | 3 | import java.util.Map; 4 | 5 | /** Pairs a hyper-parameter map with a ShellProcess callable; both references are final and set only in the constructor. */ public class ParamsAndCallable { 6 | private final Map mParams; /* NOTE(review): type arguments are missing throughout this export (presumably a String-to-String map); confirm against the real source. */ 7 | private final ShellProcess mCallable; 8 | 9 | public ParamsAndCallable(final Map params, final ShellProcess callable) { 10 | mParams = params; 11 | mCallable = callable; 12 | } 13 | 14 | public Map 
getParams() { 15 | return mParams; 16 | } 17 | 18 | public ShellProcess getCallable() { 19 | return mCallable; 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/strategy/Strategy.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.strategy; 2 | 3 | import java.util.List; 4 | 5 | /** Chooses which of several candidate results to keep; every hook except chooseSuitableIndex has a default (no-op, or echoing shouldStop for ensureIfStop). */ public interface Strategy { 6 | int chooseSuitableIndex(List results, int originalIndex, boolean isFirstTry); 7 | 8 | default void iterationEnd() { 9 | } 10 | 11 | default boolean ensureIfStop(final boolean shouldStop) { 12 | return shouldStop; 13 | } 14 | 15 | default void storeBest() { 16 | } 17 | 18 | default void restoreBest() { 19 | } 20 | 21 | default void saveCheckPoint() { 22 | } 23 | 24 | default void restoreCheckPoint() { 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/strategy/BaseStrategy.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.strategy; 2 | 3 | import java.util.List; 4 | 5 | import com.google.common.base.Preconditions; 6 | 7 | /** Greedy strategy: returns the index of the highest result (the first one on ties); throws IllegalStateException when no result is comparable (empty list, or only NaN / -Infinity). */ public class BaseStrategy implements Strategy { 8 | @Override 9 | public int chooseSuitableIndex(final List results, final int originalIndex, final boolean isFirstTry) { 10 | float best = Float.NEGATIVE_INFINITY; /* below every real score, so arbitrarily low results can still be chosen */ 11 | int index = -1; 12 | int i = -1; 13 | for (final float result : results) { 14 | ++i; 15 | if (result > best) { 16 | best = result; 17 | index = i; 18 | } 19 | } 20 | Preconditions.checkState(index != -1, "no comparable result in %s", results); 21 | return index; 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/test/java/com/hljunlp/laozhongyi/strategy/StrategyTest.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.strategy; 2 | 3 | import
java.util.List; 4 | 5 | import org.junit.Test; 6 | 7 | import com.google.common.collect.Lists; 8 | 9 | /* Smoke test only: prints 100 chosen indices and asserts nothing. NOTE(review): this uses JUnit 4 (org.junit.Test) while SimulatedAnnealingCheckPointManagerTest uses JUnit 5 (org.junit.jupiter.api.Test); confirm the build runs both engines. */ public class StrategyTest { 10 | @Test 11 | public void testTraditionalSa() { 12 | final TraiditionalSimulatedAnnealingStrategy strategy = new TraiditionalSimulatedAnnealingStrategy( 13 | (float) 0.9, 1); 14 | final List results = Lists.newArrayList(0.5f, 0.49f); 15 | for (int i = 0; i < 100; ++i) { 16 | final int index = strategy.chooseSuitableIndex(results, 0, false); 17 | System.out.println(index); 18 | strategy.iterationEnd(); 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/HyperParameterScopeItem.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi; 2 | 3 | import java.util.List; 4 | 5 | import org.apache.commons.lang3.builder.ToStringBuilder; 6 | 7 | /** One tunable key together with its list of candidate values. */ public class HyperParameterScopeItem { 8 | private final String mKey; 9 | private final List mValues; 10 | 11 | public HyperParameterScopeItem(final String key, final List values) { 12 | super(); 13 | this.mKey = key; 14 | this.mValues = values; 15 | } 16 | 17 | public String getKey() { 18 | return mKey; 19 | } 20 | 21 | public List getValues() { 22 | return mValues; 23 | } 24 | 25 | @Override 26 | public String toString() { 27 | return ToStringBuilder.reflectionToString(this); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/Utils.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi; 2 | 3 | import java.util.Map; 4 | import java.util.regex.Matcher; 5 | import java.util.regex.Pattern; 6 | 7 | import com.google.common.collect.Maps; 8 | 9 | /** Static helpers for hyper-parameter maps and for reading results out of log text. */ public class Utils { 10 | /* Returns a new TreeMap holding every entry of map plus key mapped to value (replacing any existing value for key); map itself is not modified. */ public static Map modifiedNewMap(final Map map, 11 | final String key, final String value) { 12 | final Map copied =
Maps.newTreeMap(); 13 | for (final Map.Entry entry : map.entrySet()) { 14 | copied.put(entry.getKey(), entry.getValue()); 15 | } 16 | copied.put(key, value); 17 | return copied; 18 | } 19 | 20 | /* Returns the number following the last laozhongyi_ marker in log (e.g. laozhongyi_0.85), or 0.0f when there is none. The captured number is digits with at most one decimal point (0.85, .5, 12), so text such as laozhongyi_0.85. or laozhongyi_. cannot reach Float.parseFloat and throw NumberFormatException. Negative numbers are not recognised because the sign is not captured. */ public static float logResult(final String log) { 21 | final Pattern pattern = Pattern.compile("laozhongyi_(\\d*\\.?\\d+)"); 22 | final Matcher matcher = pattern.matcher(log); 23 | float result = 0.0f; 24 | while (matcher.find()) { 25 | final String string = matcher.group(1); 26 | result = Float.parseFloat(string); 27 | } 28 | return result; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Chauncey Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/HyperParamResultManager.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi; 2 | 3 | import com.google.common.collect.Maps; 4 | 5 | import java.util.HashMap; 6 | import java.util.Map; 7 | import java.util.Optional; 8 | 9 | /** Process-wide cache from a hyper-parameter assignment to its result; every access to mResults holds the HyperParamResultManager.class monitor. */ public class HyperParamResultManager { 10 | private static final Map, Float> mResults = Maps.newHashMap(); 11 | 12 | public static synchronized Optional getResult(final Map hyperParams) { 13 | return Optional.ofNullable(mResults.get(hyperParams)); 14 | } 15 | 16 | /* Stores hyperParams by reference as the key. NOTE(review): callers must not mutate that map afterwards, or its hash changes and the entry can no longer be found. */ public static synchronized void putResult(final Map hyperParams, 17 | final float result) { 18 | mResults.put(hyperParams, result); 19 | } 20 | 21 | /* Builds a new map under the class lock, copying each key map, so the snapshot is unaffected by later cache updates. */ public static Map, Float> deepCopyResults() { 22 | Map, Float> copy = new HashMap<>(); 23 | synchronized (HyperParamResultManager.class) { 24 | for (Map.Entry, Float> entry : mResults.entrySet()) { 25 | Map hyperParamsCopy = new HashMap<>(entry.getKey()); 26 | Float resultCopy = entry.getValue(); 27 | copy.put(hyperParamsCopy, resultCopy); 28 | } 29 | } 30 | return copy; 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/HyperParameterScopeConfigReader.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi; 2 | 3 | import java.io.File; 4 | import java.io.IOException; 5 | import java.util.List; 6 | 7 | import org.apache.commons.io.Charsets; 8 | import org.apache.commons.io.FileUtils; 9 | import org.apache.commons.lang3.StringUtils; 10 | 11 | import com.google.common.collect.Lists; 12 | 13 | /** Reads the hyper-parameter scope file as UTF-8: each line is key,value1,value2,... and becomes one HyperParameterScopeItem. */ public class HyperParameterScopeConfigReader { 14 | public static List read(final String configFilePath) { 15 | final File file = new File(configFilePath); 16 | List lines; 17 | try { 18 | lines =
FileUtils.readLines(file, Charsets.UTF_8); 19 | } catch (final IOException e) { 20 | throw new IllegalStateException(e); 21 | } 22 | final List result = Lists.newArrayList(); 23 | for (final String line : lines) { 24 | // Skip blank lines (e.g. a trailing empty line): StringUtils.split returns an empty 25 | // array for an empty line, and segments[0] would then throw. 26 | if (StringUtils.isBlank(line)) { 27 | continue; 28 | } 29 | // segments[0] is the key; the remaining segments are its candidate values. 30 | final String[] segments = StringUtils.split(line, ','); 31 | final List values = Lists.newArrayList(); 32 | for (int i = 1; i < segments.length; ++i) { 33 | values.add(segments[i]); 34 | } 35 | final HyperParameterScopeItem item = new HyperParameterScopeItem(segments[0], values); 36 | result.add(item); 37 | } 38 | return result; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/checkpoint/SimulatedAnnealingCheckPointManager.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.checkpoint; 2 | 3 | import org.json.JSONObject; 4 | 5 | public class SimulatedAnnealingCheckPointManager extends CheckPointManager { 6 | /** 7 | * Constructor. 8 | * 9 | * @param checkPointPath The path of the checkpoint file.
10 | */ 11 | public SimulatedAnnealingCheckPointManager(final String checkPointPath) { 12 | super(checkPointPath); 13 | } 14 | 15 | /* Adds temperature and decayRate to the base JSON. The unchecked cast means data must be a SimulatedAnnealingCheckPointData, otherwise ClassCastException is thrown. */ @Override 16 | protected JSONObject convertToJSON(final CheckPointData data) { 17 | JSONObject jsonObject = super.convertToJSON(data); 18 | SimulatedAnnealingCheckPointData simulatedAnnealingCheckPointData = 19 | (SimulatedAnnealingCheckPointData) data; 20 | jsonObject.put("temperature", simulatedAnnealingCheckPointData.getTemperature()); 21 | jsonObject.put("decayRate", simulatedAnnealingCheckPointData.getDecayRate()); 22 | return jsonObject; 23 | } 24 | 25 | /* Rebuilds the base fields through super, then wraps them with the stored temperature and decayRate (read back as doubles and narrowed to float). */ @Override 26 | protected CheckPointData convertFromJSON(final JSONObject jsonObject) { 27 | CheckPointData data = super.convertFromJSON(jsonObject); 28 | float temperature = (float) jsonObject.getDouble("temperature"); 29 | float decayRate = (float) jsonObject.getDouble("decayRate"); 30 | return new SimulatedAnnealingCheckPointData(temperature, decayRate, 31 | data.getCurrentHyperParameters(), data.getCurrentHyperParametersIndex(), 32 | data.getBestHyperParameters(), data.getBestScore(), 33 | data.getHyperParametersToScore()); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/strategy/SimulatedAnnealingStrategy.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.strategy; 2 | 3 | import com.google.common.base.Preconditions; 4 | import com.hljunlp.laozhongyi.checkpoint.SimulatedAnnealingCheckPointData; 5 | 6 | import java.util.Random; 7 | 8 | /** Annealing base class: iterationEnd() multiplies the temperature mT by the decay rate mR unless the product would fall below 0.001, and storeBest()/restoreBest() save and restore mT. */ public abstract class SimulatedAnnealingStrategy implements Strategy { 9 | protected float mT; 10 | protected final float mR; 11 | protected final Random mRandom; 12 | protected float mBestT; /* NOTE(review): only assigned by storeBest(); calling restoreBest() first would set mT to 0.0f. */ 13 | 14 | /* r and t must both be positive (IllegalArgumentException otherwise). NOTE(review): r >= 1 is accepted, in which case mT never decreases and ensureIfStop may never confirm a stop. */ public SimulatedAnnealingStrategy(final float r, final float t) { 15 | Preconditions.checkArgument(r > 0 && t > 0); 16 | mT = t; 17 | mR = r; 18 | mRandom = new Random(); 19 | } 20 | 21 |
public float getTemperature() { 22 | return mT; 23 | } 24 | 25 | public float getDecayRate() { 26 | return mR; 27 | } 28 | 29 | /* Restores temperature and decay rate from a checkpoint. NOTE(review): unlike the other constructor this does not validate them and leaves mBestT unset. */ public SimulatedAnnealingStrategy(final SimulatedAnnealingCheckPointData checkPointData) { 30 | mT = checkPointData.getTemperature(); 31 | mR = checkPointData.getDecayRate(); 32 | mRandom = new Random(); 33 | } 34 | 35 | /* Decays mT by mR, but keeps the current value when the product would drop below 0.001. */ @Override 36 | public void iterationEnd() { 37 | final float next = mT * mR; 38 | mT = next >= 0.001 ? next : mT; 39 | System.out.println("SimulatedAnnealingStrategy iterationEnd mT:" + mT); 40 | } 41 | 42 | /* Confirms a stop request only once one more decay would take mT below 0.001. */ @Override 43 | public boolean ensureIfStop(final boolean shouldStop) { 44 | return shouldStop && mT * mR < 0.001; 45 | } 46 | 47 | @Override 48 | public void storeBest() { 49 | mBestT = mT; 50 | } 51 | 52 | @Override 53 | public void restoreBest() { 54 | mT = mBestT; 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/strategy/TraiditionalSimulatedAnnealingStrategy.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.strategy; 2 | 3 | import java.util.List; 4 | 5 | import com.google.common.base.Preconditions; 6 | 7 | /** Simulated annealing with Metropolis-style acceptance: the best-scoring index other than originalIndex is taken on the first try or when it scores higher; otherwise it is accepted with probability exp(de / mT), where de is its non-positive score difference. Throws IllegalStateException when there is no comparable alternative. */ public class TraiditionalSimulatedAnnealingStrategy extends SimulatedAnnealingStrategy { 8 | public TraiditionalSimulatedAnnealingStrategy(final float r, final float t) { 9 | super(r, t); 10 | } 11 | 12 | @Override 13 | public int chooseSuitableIndex(final List results, final int originalIndex, final boolean isFirstTry) { 14 | float max = Float.NEGATIVE_INFINITY; /* below every real score, so alternatives scoring <= -1 can still be chosen */ 15 | int maxi = -1; 16 | for (int i = 0; i < results.size(); ++i) { 17 | if (i == originalIndex) { 18 | continue; 19 | } 20 | if (max < results.get(i)) { 21 | max = results.get(i); 22 | maxi = i; 23 | } 24 | } 25 | 26 | Preconditions.checkState(maxi != -1); 27 | 28 | System.out.println( 29 | "TraiditionalSimulatedAnnealingStrategy chooseSuitableIndex: max result is " + max); 30 | 31 | if (isFirstTry || max >
results.get(originalIndex)) { 32 | if (!isFirstTry) { 33 | System.out.println( 34 | "TraiditionalSimulatedAnnealingStrategy chooseSuitableIndex: original value is " 35 | + results.get(originalIndex)); 36 | } 37 | return maxi; 38 | } else { 39 | final float de = max - results.get(originalIndex); /* de <= 0 here: the best alternative is no better than the current value */ 40 | if (Math.abs(de) < 0.0001) { /* near-ties count as no improvement, so stay */ 41 | return originalIndex; 42 | } 43 | final float prob = (float) Math.exp(de / mT); /* at most 1 because de is negative; a lower mT makes it smaller */ 44 | System.out.println( 45 | "TraiditionalSimulatedAnnealingStrategy chooseSuitableIndex prob is " + prob); 46 | if (mRandom.nextFloat() <= prob) { 47 | return maxi; 48 | } else { 49 | return originalIndex; 50 | } 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/checkpoint/SimulatedAnnealingCheckPointData.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.checkpoint; 2 | 3 | import java.util.Map; 4 | 5 | /** Checkpoint that additionally records the annealing temperature and decay rate. */ public class SimulatedAnnealingCheckPointData extends CheckPointData { 6 | private final float temperature; 7 | private final float decayRate; 8 | 9 | public SimulatedAnnealingCheckPointData(final float temperature, final float decayRate, 10 | final Map currentHyperParameters, 11 | final int currentHyperParametersIndex, 12 | final Map bestHyperParameters, 13 | final float bestScore, 14 | final Map, Float> hyperParametersToScore) { 15 | super(currentHyperParameters, currentHyperParametersIndex, bestHyperParameters, bestScore 16 | , hyperParametersToScore); 17 | this.temperature = temperature; 18 | this.decayRate = decayRate; 19 | } 20 | 21 | public float getTemperature() { 22 | return temperature; 23 | } 24 | 25 | public float getDecayRate() { 26 | return decayRate; 27 | } 28 | 29 | @Override 30 | public String toString() { 31 | return "SimulatedAnnealingCheckPointData{" + 32 | "temperature=" + temperature + 33 | ", decayRate=" + decayRate + 34 | "} " + super.toString(); 35 | } 36 | 37 |
@Override 38 | public CheckPointData deepCopy() { 39 | CheckPointData checkPointData = super.deepCopy(); 40 | return new SimulatedAnnealingCheckPointData(temperature, decayRate, 41 | checkPointData.getCurrentHyperParameters(), 42 | checkPointData.getCurrentHyperParametersIndex(), 43 | checkPointData.getBestHyperParameters(), 44 | checkPointData.getBestScore(), 45 | checkPointData.getHyperParametersToScore()); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/HyperParameterConfig.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi; 2 | 3 | import java.io.File; 4 | import java.io.IOException; 5 | import java.util.List; 6 | import java.util.Map; 7 | import java.util.Map.Entry; 8 | 9 | import org.apache.commons.io.Charsets; 10 | import org.apache.commons.io.FileUtils; 11 | import org.apache.commons.lang3.StringUtils; 12 | 13 | import com.google.common.base.Preconditions; 14 | import com.google.common.collect.Lists; 15 | import com.google.common.collect.Maps; 16 | 17 | /** Writes a hyper-parameter map to a config file as UTF-8 "key = value" lines and can read the file back to verify its contents. */ public class HyperParameterConfig { 18 | private final String mConfigFilePath; 19 | 20 | public HyperParameterConfig(final String configFilePath) { 21 | mConfigFilePath = configFilePath; 22 | } 23 | 24 | public void write(final Map items) { 25 | final List lines = Lists.newArrayList(); 26 | for (final Entry item : items.entrySet()) { 27 | lines.add(item.getKey() + " = " + item.getValue()); 28 | } 29 | 30 | try { 31 | FileUtils.writeLines(new File(mConfigFilePath), Charsets.UTF_8.name(), lines); /* same charset toMap() reads with, instead of the platform default */ 32 | } catch (final IOException e) { 33 | throw new IllegalStateException(e); 34 | } 35 | } 36 | 37 | /* Re-reads the file and checks it holds exactly items; throws IllegalStateException on any mismatch. */ public void check(final Map items) { 38 | final Map map = toMap(); 39 | Preconditions.checkState(map.size() == items.size()); 40 | for (final Entry item : items.entrySet()) { 41 | Preconditions.checkState(StringUtils.equals(map.get(item.getKey()), item.getValue())); /* compare contents: == only matches the same String instance, which a value freshly read from the file never is */ 42 | } 43 | } 44 | 45 | private
Map toMap() { 46 | final File file = new File(mConfigFilePath); 47 | List lines; 48 | try { 49 | lines = FileUtils.readLines(file, Charsets.UTF_8); 50 | } catch (final IOException e) { 51 | throw new IllegalStateException(e); 52 | } 53 | final Map map = Maps.newTreeMap(); /* parses the "key = value" lines produced by write() into a sorted map */ 54 | for (final String line : lines) { 55 | final String[] segs = StringUtils.splitByWholeSeparator(line, " = ", 2); /* split once on the whole separator, so values containing spaces, '=' or " = " (and empty values) survive intact */ 56 | Preconditions.checkState(segs.length == 2, "malformed config line: %s", line); 57 | map.put(segs[0], segs[1]); 58 | } 59 | 60 | return map; 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/GeneratedFileManager.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi; 2 | 3 | import java.io.File; 4 | import java.io.IOException; 5 | import java.util.Collections; 6 | import java.util.List; 7 | import java.util.Map; 8 | import java.util.Set; 9 | 10 | import org.apache.commons.io.FileUtils; 11 | import org.apache.commons.io.FilenameUtils; 12 | import org.apache.commons.lang3.StringUtils; 13 | import org.joda.time.LocalDateTime; 14 | 15 | import com.google.common.collect.Lists; 16 | 17 | /** Creates timestamped output directories under user.home and derives per-assignment file paths inside them. NOTE(review): directory names embed LocalDateTime.toString(), whose ':' characters are not allowed in Windows file names. */ public class GeneratedFileManager { 18 | private static String mLogDirPath; 19 | private static String mHyperParameterConfigDirPath; 20 | 21 | public static void mkdirForHyperParameterConfig() { 22 | final String homeDir = System.getProperty("user.home"); 23 | final String logDir = "hyper" + new LocalDateTime().toString(); 24 | mHyperParameterConfigDirPath = FilenameUtils.concat(homeDir, logDir); 25 | try { 26 | FileUtils.forceMkdir(new File(mHyperParameterConfigDirPath)); 27 | } catch (final IOException e) { 28 | throw new IllegalStateException(e); 29 | } 30 | } 31 | 32 | public static void mkdirForLog() { 33 | final String homeDir = System.getProperty("user.home"); 34 | final String logDir = "log" + new LocalDateTime().toString(); 35 | mLogDirPath = FilenameUtils.concat(homeDir, logDir); 36 | try { 37 | FileUtils.forceMkdir(new File(mLogDirPath));
38 | } catch (final IOException e) { 39 | throw new IllegalStateException(e); 40 | } 41 | } 42 | 43 | /* Path of the config file for this assignment inside the directory created by mkdirForHyperParameterConfig(). */ public static String getHyperParameterConfigFileFullPath(final Map config, 44 | final Set multiValuesKeys) { 45 | return FilenameUtils.concat(mHyperParameterConfigDirPath, fileName(config, multiValuesKeys)); 46 | } 47 | 48 | /* Path of the log file for this assignment inside the directory created by mkdirForLog(). */ public static String getLogFileFullPath(final Map config, 49 | final Set multiValuesKeys) { 50 | return FilenameUtils.concat(mLogDirPath, fileName(config, multiValuesKeys)); 51 | } 52 | 53 | // Concatenates key and value for every multi-valued key, in sorted key order; '/' in a value 54 | // becomes '-' so values containing '/' do not create sub-directories. NOTE(review): key and 55 | // value are joined without a separator, so distinct assignments could map to the same name. 56 | private static String fileName(final Map config, final Set multiValuesKeys) { 57 | final List keys = Lists.newArrayList(); 58 | for (final String key : config.keySet()) { 59 | if (multiValuesKeys.contains(key)) { 60 | keys.add(key); 61 | } 62 | } 63 | Collections.sort(keys); 64 | 65 | final StringBuilder fileName = new StringBuilder(); 66 | for (final String key : keys) { 67 | fileName.append(key).append(config.get(key).replace("/", "-")); 68 | } 69 | return fileName.toString(); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/test/java/com/hljunlp/laozhongyi/strategy/checkpoint/SimulatedAnnealingCheckPointManagerTest.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.strategy.checkpoint; 2 | 3 | import com.hljunlp.laozhongyi.checkpoint.SimulatedAnnealingCheckPointData; 4 | import com.hljunlp.laozhongyi.checkpoint.SimulatedAnnealingCheckPointManager; 5 | import org.junit.jupiter.api.Assertions; 6 | import org.junit.jupiter.api.Test; 7 | 8 | import java.util.HashMap; 9 | import java.util.Map; 10 | 11 | public class SimulatedAnnealingCheckPointManagerTest { 12 | 13 | @Test 14 | void saveAndLoadCheckPoint() { 15 | // Create a SimulatedAnnealingCheckPointData object with some values 16
| Map currentHyperParameters = new HashMap<>(); 17 | currentHyperParameters.put("param1", "value1"); 18 | currentHyperParameters.put("param2", "value2"); 19 | Map bestHyperParameters = new HashMap<>(); 20 | bestHyperParameters.put("param3", "value3"); 21 | bestHyperParameters.put("param4", "value4"); 22 | Map, Float> hyperParametersToScore = new HashMap<>(); 23 | hyperParametersToScore.put(currentHyperParameters, 0.8f); 24 | hyperParametersToScore.put(bestHyperParameters, 0.9f); 25 | for (int i = 0; i < 10; i++) { 26 | Map hyperParameters = new HashMap<>(); 27 | for (int j = 0; j < 10; j++) { 28 | hyperParameters.put("param" + j, "value" + i + "_" + j); /* values depend on i so the ten maps are distinct keys in hyperParametersToScore */ 29 | } 30 | hyperParametersToScore.put(hyperParameters, 0.1f * i); 31 | } 32 | SimulatedAnnealingCheckPointData data = new SimulatedAnnealingCheckPointData(1.0f, 0.5f, 33 | currentHyperParameters, 2, bestHyperParameters, 0.9f, hyperParametersToScore); 34 | 35 | // Save the SimulatedAnnealingCheckPointData object to a file 36 | String filePath = "checkpoint.json"; 37 | SimulatedAnnealingCheckPointManager manager = 38 | new SimulatedAnnealingCheckPointManager(filePath); 39 | manager.save(data, "checkpoint-test.json"); 40 | 41 | // Load the SimulatedAnnealingCheckPointData object from the file 42 | SimulatedAnnealingCheckPointData loadedData = 43 | (SimulatedAnnealingCheckPointData) manager.load("checkpoint-test.json"); 44 | 45 | // Compare the loaded SimulatedAnnealingCheckPointData object with the original one 46 | Assertions.assertEquals(data.getTemperature(), loadedData.getTemperature()); 47 | Assertions.assertEquals(data.getDecayRate(), loadedData.getDecayRate()); 48 | Assertions.assertEquals(data.getCurrentHyperParametersIndex(), 49 | loadedData.getCurrentHyperParametersIndex()); 50 | Assertions.assertEquals(data.getBestScore(), loadedData.getBestScore()); 51 | Assertions.assertEquals(data.getCurrentHyperParameters().size(), 52 | loadedData.getCurrentHyperParameters().size()); 53 |
Assertions.assertEquals(data.getBestHyperParameters().size(), 54 | loadedData.getBestHyperParameters().size()); 55 | Assertions.assertEquals(data.getHyperParametersToScore().size(), 56 | loadedData.getHyperParametersToScore().size()); 57 | for (Map.Entry entry : data.getCurrentHyperParameters().entrySet()) { 58 | Assertions.assertEquals(entry.getValue(), 59 | loadedData.getCurrentHyperParameters().get(entry.getKey())); 60 | } 61 | for (Map.Entry entry : data.getBestHyperParameters().entrySet()) { 62 | Assertions.assertEquals(entry.getValue(), 63 | loadedData.getBestHyperParameters().get(entry.getKey())); 64 | } 65 | for (Map.Entry, Float> entry : 66 | data.getHyperParametersToScore().entrySet()) { 67 | Assertions.assertEquals(entry.getValue(), 68 | loadedData.getHyperParametersToScore().get(entry.getKey())); 69 | } 70 | } 71 | 72 | @Test 73 | /* Expects load() to throw for a missing checkpoint. NOTE(review): passes only while no checkpoint.json exists in the working directory; saveAndLoadCheckPoint writes checkpoint-test.json and never deletes it. */ void loadNonExistentCheckPoint() { 74 | String filePath = "checkpoint.json"; 75 | SimulatedAnnealingCheckPointManager manager = 76 | new SimulatedAnnealingCheckPointManager(filePath); 77 | Assertions.assertThrows(RuntimeException.class, manager::load); 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/process/ShellProcess.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.process; 2 | 3 | import com.google.common.base.Charsets; 4 | import com.hljunlp.laozhongyi.GeneratedFileManager; 5 | import com.hljunlp.laozhongyi.HyperParamResultManager; 6 | import com.hljunlp.laozhongyi.HyperParameterConfig; 7 | import com.hljunlp.laozhongyi.Utils; 8 | import com.hljunlp.laozhongyi.checkpoint.CheckPointManager; 9 | import org.apache.commons.exec.DefaultExecutor; 10 | import org.apache.commons.exec.ExecuteException; 11 | import org.apache.commons.exec.PumpStreamHandler; 12 | import org.apache.commons.io.FileUtils; 13 | 14 | import java.io.File; 15 | import java.io.IOException; 16 |
import java.io.OutputStream; 17 | import java.nio.file.Files; 18 | import java.nio.file.Paths; 19 | import java.util.Map; 20 | import java.util.Optional; 21 | import java.util.Set; 22 | import java.util.concurrent.Callable; 23 | 24 | /** Callable that scores one hyper-parameter assignment by running the user command and parsing the score from the command's log file. */ public class ShellProcess implements Callable { 25 | private final Map mCopiedHyperParameter; 26 | private final Set mMultiValueKeys; 27 | private final String mCmdString; 28 | private final String mWorkingDir; 29 | 30 | private final CheckPointManager mCheckPointManager; 31 | 32 | public ShellProcess(final Map copiedHyperParameter, 33 | final Set multiValueKeys, final String cmdString, 34 | final String workingDir, CheckPointManager mCheckPointManager) { 35 | mCopiedHyperParameter = copiedHyperParameter; 36 | mMultiValueKeys = multiValueKeys; 37 | mCmdString = cmdString; 38 | mWorkingDir = workingDir; 39 | this.mCheckPointManager = mCheckPointManager; 40 | } 41 | 42 | /** Returns the score cached in HyperParamResultManager if present. Otherwise writes the config file, replaces every "{}" in the command with the config file path, runs the command with its output pumped into the log file, parses the score with Utils.logResult, caches it and, when a CheckPointManager was supplied, saves a checkpoint incrementally. IOExceptions are rethrown as IllegalStateException. */ public float innerCall() { 43 | final Optional cachedResult = 44 | HyperParamResultManager.getResult(mCopiedHyperParameter); 45 | if (cachedResult.isPresent()) { 46 | return cachedResult.get(); 47 | } 48 | 49 | final String configFilePath = GeneratedFileManager 50 | .getHyperParameterConfigFileFullPath(mCopiedHyperParameter, mMultiValueKeys); 51 | final String newCmdString = mCmdString.replace("{}", configFilePath); 52 | final HyperParameterConfig config = new HyperParameterConfig(configFilePath); 53 | config.write(mCopiedHyperParameter); 54 | final String logFileFullPath = GeneratedFileManager 55 | .getLogFileFullPath(mCopiedHyperParameter, mMultiValueKeys); 56 | System.out.println("logFileFullPath:" + logFileFullPath); 57 | try (OutputStream os = Files.newOutputStream(Paths.get(logFileFullPath))) { 58 | final DefaultExecutor executor = new DefaultExecutor(); 59 | executor.setStreamHandler(new PumpStreamHandler(os)); 60 | if (mWorkingDir != null) { 61 | executor.setWorkingDirectory(new File(mWorkingDir)); 62 | } 63 | System.out.println("begin to execute " +
newCmdString); 64 | final org.apache.commons.exec.CommandLine commandLine = 65 | org.apache.commons.exec.CommandLine 66 | .parse(newCmdString); 67 | try { 68 | executor.execute(commandLine); 69 | } catch (final ExecuteException e) { /* Only printed, not rethrown: the log written so far is still parsed below. */ 70 | System.out.println(e.getMessage()); 71 | } catch (final IOException e) { 72 | throw new IllegalStateException(e); 73 | } 74 | 75 | final String log = FileUtils.readFileToString(new File(logFileFullPath), 76 | Charsets.UTF_8); 77 | final float result = Utils.logResult(log); 78 | 79 | HyperParamResultManager.putResult(mCopiedHyperParameter, result); 80 | if (mCheckPointManager != null) { 81 | mCheckPointManager.saveIncrementally(); 82 | } 83 | 84 | return result; 85 | } catch (final IOException e) { 86 | throw new IllegalStateException(e); 87 | } 88 | } 89 | 90 | /** Delegates to innerCall(), printing the stack trace of any RuntimeException before rethrowing it. */ @Override 91 | public Float call() { 92 | try { 93 | return innerCall(); 94 | } catch (final RuntimeException e) { 95 | e.printStackTrace(); 96 | throw e; 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/checkpoint/CheckPointData.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.checkpoint; 2 | 3 | import com.hljunlp.laozhongyi.HyperParamResultManager; 4 | 5 | import java.util.HashMap; 6 | import java.util.Map; 7 | 8 | /** Snapshot of a tuning run: the current hyper-parameters and index, the best hyper-parameters and score, and the score of every evaluated hyper-parameter set. */ public class CheckPointData { 9 | private final Map currentHyperParameters; 10 | private final int currentHyperParametersIndex; 11 | private final Map bestHyperParameters; 12 | private final float bestScore; 13 | 14 | private final Map, Float> hyperParametersToScore; 15 | 16 | public CheckPointData(final Map 
currentHyperParameters, 17 | final int currentHyperParametersIndex, final Map bestHyperParameters, final float bestScore, 19 | Map, Float> hyperParametersToScore) { 20 | this.currentHyperParameters = currentHyperParameters; 21 | this.currentHyperParametersIndex =
currentHyperParametersIndex; 22 | this.bestHyperParameters = bestHyperParameters; 23 | this.bestScore = bestScore; 24 | this.hyperParametersToScore = hyperParametersToScore; 25 | } 26 | 27 | public Map getCurrentHyperParameters() { 28 | return currentHyperParameters; 29 | } 30 | 31 | public int getCurrentHyperParametersIndex() { 32 | return currentHyperParametersIndex; 33 | } 34 | 35 | public Map getBestHyperParameters() { 36 | return bestHyperParameters; 37 | } 38 | 39 | public float getBestScore() { 40 | return bestScore; 41 | } 42 | 43 | public Map, Float> getHyperParametersToScore() { 44 | return hyperParametersToScore; 45 | } 46 | 47 | @Override 48 | public String toString() { 49 | return "CheckPointData{" + 50 | "currentHyperParameters=" + currentHyperParameters + 51 | ", currentHyperParametersIndex=" + currentHyperParametersIndex + 52 | ", bestHyperParameters=" + bestHyperParameters + 53 | ", bestScore=" + bestScore + 54 | ", hyperParametersToScore=" + hyperParametersToScore + 55 | '}'; 56 | } 57 | 58 | /** Returns a new CheckPointData whose maps, including each key map of hyperParametersToScore, are fresh HashMap copies. */ public CheckPointData deepCopy() { 59 | // Create new instances of the maps 60 | Map newCurrentHyperParameters = new HashMap<>(); 61 | for (Map.Entry entry : currentHyperParameters.entrySet()) { 62 | String key = entry.getKey(); 63 | String value = entry.getValue(); 64 | newCurrentHyperParameters.put(key, value); 65 | } 66 | 67 | Map newBestHyperParameters = new HashMap<>(); 68 | for (Map.Entry entry : bestHyperParameters.entrySet()) { 69 | String key = entry.getKey(); 70 | String value = entry.getValue(); 71 | newBestHyperParameters.put(key, value); 72 | } 73 | 74 | Map, Float> newHyperParametersToScore = new HashMap<>(); 75 | for (Map.Entry, Float> entry : hyperParametersToScore.entrySet()) { 76 | Map key = new HashMap<>(); 77 | for (Map.Entry innerEntry : entry.getKey().entrySet()) { 78 | String innerKey = innerEntry.getKey(); 79 | String innerValue = 
innerEntry.getValue(); 80 | key.put(innerKey, innerValue); 81 | } 82 | float value = entry.getValue(); 83 |
newHyperParametersToScore.put(key, value); 84 | } 85 | 86 | // Create a new instance of CheckPointData with the copied data 87 | return new CheckPointData(newCurrentHyperParameters, currentHyperParametersIndex, 88 | newBestHyperParameters, bestScore, newHyperParametersToScore); 89 | } 90 | 91 | /** Returns a deep copy of this checkpoint whose score map also contains every entry of HyperParamResultManager.deepCopyResults(); this instance is not modified. */ public synchronized CheckPointData addHyperParameterToScoreMap() { 92 | CheckPointData newCheckPointData = deepCopy(); 93 | for (Map.Entry, Float> entry : 94 | HyperParamResultManager.deepCopyResults().entrySet()) { 95 | newCheckPointData.getHyperParametersToScore().put(entry.getKey(), entry.getValue()); 96 | } 97 | return newCheckPointData; 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Laozhongyi is a simple but effective automatic hyper-parameter tuning program based on grid search and simulated annealing. 2 | # Why use Laozhongyi? 3 | * There are lots of boring things in hyperparameter tuning, e.g., configuring a bunch of values for hyperparameters, naming a bunch of log files, launching a bunch of processes, comparing results among log files, etc. So manual operation is both time consuming and error prone. 4 | * Some existed automatic hyperparameter tuning programs rely on Python, while Laozhongyi only relies on string parsing. Thus you can use laozhongyi with any programming language and any deep learning library. 5 | # Getting started 6 | We will illustrate how to use Laozhongyi in the following. 7 | ## An example of the tuning range's config file 8 | ``` 9 | lr,0.01,0.001,0.0001 10 | dropout,0,0.1,0.2,0.3,0.4,0.5 11 | batch_size,1,2,4,8,16,32,64 12 | pretrained,/home/xxx/modelxxx 13 | ``` 14 | Each line is a hyperparameter name followed by a list of values to be tuned, separated by commas. In particular, if the parameter has only one 
value, it means a fixed value and will not be tuned.
15 | ## Requirements for your program to be tuned 16 | * Your program should exit once it performs sufficiently well (e.g., when the F1 score on the validation set has not improved for ten consecutive epochs). Otherwise, Laozhongyi will abort your program when it reaches the elapsed-time limit, which wastes tuning time. 17 | * Your program should write its log to standard output (Laozhongyi redirects standard output to a corresponding log file). The log should contain a string such as laozhongyi_0.8, where 0.8 is the best performance on the validation set. 18 | * Your program should take the path of the hyperparameter config file generated by Laozhongyi as a command-line argument and parse that file. 19 | 20 | An example of a generated hyperparameter config file is as follows: 21 | ``` 22 | lr = 0.001 23 | dropout = 0.1 24 | batch_size = 64 25 | pretrained = /home/xxx/modelxxx 26 | ``` 27 | ## Run 28 | This project is built with Java 8 and Maven, so you can run *mvn clean package* to generate the target folder, or download a build from [releases](https://github.com/chncwang/laozhongyi/releases). 29 | The command-line arguments of Laozhongyi are as follows: 30 | ``` 31 | usage: laozhongyi 32 | -c cmd 33 | -rt program runtime upper bound in minutes 34 | -s scope file path 35 | -sar simulated annealing ratio 36 | -sat simulated annealing initial temperature 37 | -strategy base or sa 38 | -wd working directory 39 | ``` 40 | * -c is your program's command line, which should be quoted, e.g., "python3 train.py -train train.txt -dev dev.txt -test test.txt -hyper {}", where {} will be replaced by Laozhongyi with the path of the hyperparameter config file. 41 | * -rt is the upper limit of each process's running time in minutes. For example, *-rt 20* means that a process that has not exited after 20 minutes will be aborted. 42 | * -s is the path to the config file that specifies the hyperparameter tuning ranges.
43 | * -strategy is the search strategy: *base* means coordinate descent and *sa* means simulated annealing. 44 | * -sar is the temperature's decay rate. 45 | * -sat is the initial temperature. 46 | 47 | A complete example is as follows: 48 | ```Bash 49 | cd target 50 | java -cp "*:lib/*" com.hljunlp.laozhongyi.Laozhongyi -s /home/wqs/laozhongyi.config\ 51 | -c "python3 train.py -train train.txt -dev dev.txt -test test.txt -hyper {}"\ 52 | -sar 0.9 -sat 1 -strategy sa -rt 5 53 | ``` 54 | We recommend running Laozhongyi inside *screen* so that it keeps running after you log out. 55 | 56 | When Laozhongyi starts, it creates a log directory with a timestamp suffix and a hyperparameter config directory in your home directory. 57 | # Features 58 | ## Process Management 59 | * Laozhongyi supports multi-process hyperparameter tuning, currently with up to 8 processes. 60 | * Your program should exit when its performance is unlikely to improve any further; otherwise it will be killed when it reaches the elapsed-time limit. 61 | ## Search Strategy 62 | ### Coordinate Descent 63 | The coordinate descent method tries all the hyperparameters in a loop. For each hyperparameter, Laozhongyi tries all of its values and selects the one that performs best. The algorithm stops when the selected values of all the hyperparameters no longer change. 64 | ### Simulated Annealing 65 | Because coordinate descent tends to get stuck in a local optimum, Laozhongyi also provides a simulated annealing strategy to alleviate this problem. 66 | # Questions?
67 | Email: chncwang@gmail.com 68 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 6 | 4.0.0 7 | hljunlp 8 | laozhongyi 9 | 0.0.1-SNAPSHOT 10 | laozhongyi 11 | jar 12 | 13 | 14 | junit 15 | junit 16 | 4.11 17 | test 18 | 19 | 20 | com.google.guava 21 | guava 22 | 18.0 23 | 24 | 25 | commons-io 26 | commons-io 27 | 2.4 28 | 29 | 30 | org.apache.commons 31 | commons-lang3 32 | 3.4 33 | 34 | 35 | commons-cli 36 | commons-cli 37 | 1.4 38 | 39 | 40 | joda-time 41 | joda-time 42 | 2.3 43 | 44 | 45 | org.apache.commons 46 | commons-exec 47 | 1.3 48 | 49 | 50 | org.json 51 | json 52 | 20210307 53 | 54 | 55 | org.junit.jupiter 56 | junit-jupiter-api 57 | 5.8.1 58 | test 59 | 60 | 61 | 62 | UTF-8 63 | UTF-8 64 | 65 | 66 | 67 | 68 | org.apache.maven.plugins 69 | maven-dependency-plugin 70 | 71 | 72 | copy-dependencies 73 | package 74 | 75 | copy-dependencies 76 | 77 | 78 | ${project.build.directory}/lib 79 | 80 | 81 | 82 | 83 | 84 | maven-compiler-plugin 85 | 3.1 86 | 87 | 1.8 88 | 1.8 89 | UTF-8 90 | 91 | 92 | 93 | org.apache.maven.plugins 94 | maven-surefire-plugin 95 | 2.12 96 | 97 | 98 | **/*IntegrationTests.java 99 | 100 | 101 | 102 | 103 | maven-failsafe-plugin 104 | 2.6 105 | 106 | 107 | **/*IntegrationTests.java 108 | 109 | 110 | 111 | 112 | 113 | integration-test 114 | 115 | 116 | 117 | 118 | 119 | org.codehaus.mojo 120 | cobertura-maven-plugin 121 | 2.5.2 122 | 123 | 124 | html 125 | xml 126 | 127 | 128 | 129 | 130 | package 131 | 132 | cobertura 133 | 134 | 135 | 136 | 137 | 138 | org.apache.maven.plugins 139 | maven-enforcer-plugin 140 | 1.0 141 | 142 | 143 | 144 | enforce-ban-circular-dependencies 145 | 146 | enforce 147 | 148 | 149 | 150 | 151 | 152 | true 153 | 154 | 155 | 156 | 157 | 158 | org.codehaus.mojo 159 | extra-enforcer-rules 160 | 1.0-alpha-4 161 | 162 | 163 | 164 | 165 | org.codehaus.mojo 166 | findbugs-maven-plugin 167 
| 3.0.0 168 | 169 | false 170 | true 171 | Low 172 | Max 173 | 174 | 175 | 176 | verify 177 | 178 | check 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/checkpoint/CheckPointManager.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi.checkpoint; 2 | 3 | import org.json.JSONException; 4 | import org.json.JSONObject; 5 | 6 | import java.io.File; 7 | import java.io.FileWriter; 8 | import java.io.IOException; 9 | import java.nio.file.Files; 10 | import java.nio.file.Paths; 11 | import java.util.*; 12 | 13 | /** Saves checkpoints as JSON files named <checkPointPath>.<timestamp> and loads the one with the largest numeric suffix. */ public class CheckPointManager { 14 | private final String mCheckPointPath; 15 | private CheckPointData mCheckPointData; 16 | 17 | /** 18 | * Constructor. 19 | * 20 | * @param checkPointPath The path of the checkpoint file. 21 | */ 22 | public CheckPointManager(final String checkPointPath) { 23 | mCheckPointPath = checkPointPath; 24 | } 25 | 26 | protected JSONObject convertToJSON(final CheckPointData data) { 27 | JSONObject jsonObject = new JSONObject(); 28 | jsonObject.put("currentHyperParameters", data.getCurrentHyperParameters()); 29 | jsonObject.put("currentHyperParametersIndex", data.getCurrentHyperParametersIndex()); 30 | jsonObject.put("bestHyperParameters", data.getBestHyperParameters()); 31 | jsonObject.put("bestScore", data.getBestScore()); 32 | Map hyperParametersToScore = new HashMap<>(); 33 | for (Map.Entry, Float> entry : data.getHyperParametersToScore() 34 | .entrySet()) { 35 | hyperParametersToScore.put(mapToString(entry.getKey()), entry.getValue().toString()); 36 | } 37 | jsonObject.put("hyperParametersToScore", hyperParametersToScore); 38 | return jsonObject; 39 | } 40 | 41 | /** 42 | * Save 
the checkpoint as a JSON file. 43 | */ 44 | public synchronized void save(final CheckPointData data) { 45 | String savePath = mCheckPointPath + "."
+ System.currentTimeMillis(); 46 | save(data, savePath); 47 | } 48 | 49 | /** Remembers a deep copy of data (used later by saveIncrementally) and writes data as JSON to savePath; IOExceptions are rethrown as RuntimeException. */ public synchronized void save(final CheckPointData data, final String savePath) { 50 | mCheckPointData = data.deepCopy(); 51 | JSONObject jsonObject = convertToJSON(data); 52 | try (FileWriter fileWriter = new FileWriter(savePath)) { 53 | /* try-with-resources closes the writer even when write() throws. */ 54 | fileWriter.write(jsonObject.toString()); 55 | 56 | } catch (IOException e) { 57 | throw new RuntimeException(e); 58 | } 59 | } 60 | 61 | /** No-op until save() has been called once; afterwards merges all results from HyperParamResultManager into the remembered data and saves it to a new timestamped file. */ public synchronized void saveIncrementally() { 62 | if (mCheckPointData == null) { 63 | return; 64 | } 65 | mCheckPointData = mCheckPointData.addHyperParameterToScoreMap(); 66 | save(mCheckPointData); 67 | } 68 | 69 | protected static Map jsonObjectToMap(JSONObject jsonObject) { 70 | Map map = new HashMap<>(); 71 | Iterator keys = jsonObject.keys(); 72 | while (keys.hasNext()) { 73 | String key = keys.next(); 74 | String value = jsonObject.getString(key); 75 | map.put(key, value); 76 | } 77 | return map; 78 | } 79 | 80 | protected CheckPointData convertFromJSON(JSONObject jsonObject) { 81 | Map currentHyperParameters = jsonObjectToMap(jsonObject.getJSONObject( 82 | "currentHyperParameters")); 83 | int currentHyperParametersIndex = jsonObject.getInt("currentHyperParametersIndex"); 84 | Map bestHyperParameters = jsonObjectToMap(jsonObject.getJSONObject( 85 | "bestHyperParameters")); 86 | float bestScore = (float) jsonObject.getDouble("bestScore"); 87 | 88 | Map, Float> hyperParametersToScore = new HashMap<>(); 89 | if (jsonObject.has("hyperParametersToScore")) { 90 | JSONObject hyperParametersToScoreJSONObject = jsonObject.getJSONObject( 91 | 
"hyperParametersToScore"); 92 | Map map = hyperParametersToScoreJSONObject.toMap(); 93 | 94 | for (Map.Entry entry : map.entrySet()) { 95 | String keyString = entry.getKey(); 96 | Float value = Float.parseFloat(entry.getValue().toString()); 97 | Map key = stringToMap(keyString); 98 | hyperParametersToScore.put(key, value); 99 | } 100 | } 101 | 102 | return new
CheckPointData(currentHyperParameters, currentHyperParametersIndex, 103 | bestHyperParameters, bestScore, hyperParametersToScore); 104 | } 105 | 106 | public CheckPointData load() { 107 | return load(mCheckPointPath); 108 | } 109 | 110 | /** Returns the file in fullFileName's directory named <base name>.<numeric suffix> with the largest suffix, or null if there is none; files whose suffix is not a number are ignored. NOTE(review): a path without File.separator makes lastIndexOf return -1 and substring throw; loadNonExistentCheckPoint currently expects that RuntimeException. */ public static File getFileWithLargestSuffix(String fullFileName) { 111 | final String dir = fullFileName.substring(0, fullFileName.lastIndexOf(File.separator)); final String baseName = fullFileName.substring(fullFileName.lastIndexOf(File.separator) + 1); 112 | System.out.println("getFileWithLargestSuffix: dir = " + dir); 113 | File[] files = Objects.requireNonNull(new File(dir).listFiles()); 114 | System.out.println("getFileWithLargestSuffix: files = " + Arrays.toString(files)); 115 | final List fullFileNamesUnderDir = new ArrayList<>(); 116 | for (File file : files) { 117 | final String name = file.getName(); /* compare bare file names rather than absolute paths so relative checkpoint paths also match; the length check avoids charAt overflowing when the base file itself exists */ if (name.length() > baseName.length() && name.startsWith(baseName) && 118 | name.charAt(baseName.length()) == '.') { 119 | fullFileNamesUnderDir.add(file.getAbsolutePath()); 120 | } 121 | } 122 | long largestSuffix = -1; 123 | String largestSuffixFileName = null; 124 | for (String fileName : fullFileNamesUnderDir) { 125 | String suffix = fileName.substring(fileName.lastIndexOf(".") + 1); 126 | final long suffixLong; try { suffixLong = Long.parseLong(suffix); } catch (NumberFormatException e) { continue; /* e.g. <name>.bak: not a checkpoint timestamp */ } 127 | if (suffixLong > largestSuffix) { 128 | largestSuffix = suffixLong; 129 | largestSuffixFileName 
= fileName; 130 | } 131 | } 132 | if (largestSuffixFileName == null) { 133 | return null; 134 | } else { 135 | return new File(largestSuffixFileName); 136 | } 137 | } 138 | 139 | /** Loads the newest checkpoint for loadPath, or returns null when none exists; read or JSON errors are rethrown as RuntimeException. */ public CheckPointData load(String loadPath) { 140 | File file = getFileWithLargestSuffix(loadPath); 141 | if (file == null) { 142 | System.out.println("No checkpoint found at " + loadPath); 143 | return null; 144 | } 145 | try { 146 | System.out.println("Loading checkpoint from " + file.getAbsolutePath()); 147 | String jsonString = new String(Files.readAllBytes(Paths.get(file.getAbsolutePath()))); 148 | JSONObject jsonObject = new JSONObject(jsonString); 149 | return convertFromJSON(jsonObject); 150 | } catch (IOException | JSONException
e) { 151 | throw new RuntimeException(e); 152 | } 153 | } 154 | 155 | public static String mapToString(Map passedMap) { 156 | // Create a TreeMap to fix the order of the keys in the string. 157 | Map map = new TreeMap<>(passedMap); 158 | StringBuilder sb = new StringBuilder(); 159 | for (Map.Entry entry : map.entrySet()) { 160 | sb.append(entry.getKey()); 161 | sb.append("="); 162 | sb.append(entry.getValue()); 163 | sb.append(","); 164 | } 165 | if (sb.length() > 0) { 166 | sb.deleteCharAt(sb.length() - 1); 167 | } 168 | return sb.toString(); 169 | } 170 | 171 | /** Inverse of mapToString: parses "k1=v1,k2=v2". Each entry is split at its first '=' only, so empty values and values containing '=' survive; values containing ',' still cannot round-trip. */ public static Map stringToMap(String str) { 172 | Map map = new HashMap<>(); 173 | String[] entries = str.split(","); 174 | for (String entry : entries) { 175 | String[] parts = entry.split("=", 2); 176 | if (parts.length == 2) { 177 | map.put(parts[0], parts[1]); 178 | } 179 | } 180 | return map; 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /src/main/java/com/hljunlp/laozhongyi/Laozhongyi.java: -------------------------------------------------------------------------------- 1 | package com.hljunlp.laozhongyi; 2 | 3 | import com.google.common.base.Preconditions; 4 | import com.google.common.collect.Lists; 5 | import com.google.common.collect.Maps; 6 | import com.google.common.collect.Sets; 7 | import com.hljunlp.laozhongyi.checkpoint.CheckPointData; 8 | import com.hljunlp.laozhongyi.checkpoint.CheckPointManager; 9 | import com.hljunlp.laozhongyi.checkpoint.SimulatedAnnealingCheckPointData; 10 | import com.hljunlp.laozhongyi.checkpoint.SimulatedAnnealingCheckPointManager; 11 | import com.hljunlp.laozhongyi.process.ParamsAndCallable; 12 | import com.hljunlp.laozhongyi.process.ShellProcess; 13 | import com.hljunlp.laozhongyi.strategy.BaseStrategy; 14 | import 
com.hljunlp.laozhongyi.strategy.SimulatedAnnealingStrategy; 15 | import com.hljunlp.laozhongyi.strategy.Strategy; 16 | import com.hljunlp.laozhongyi.strategy.TraiditionalSimulatedAnnealingStrategy;
17 | import org.apache.commons.cli.*; 18 | import org.apache.commons.lang3.StringUtils; 19 | import org.apache.commons.lang3.tuple.ImmutablePair; 20 | import org.apache.commons.lang3.tuple.MutablePair; 21 | import org.apache.commons.lang3.tuple.Pair; 22 | 23 | import java.util.*; 24 | import java.util.Map.Entry; 25 | import java.util.concurrent.ExecutionException; 26 | import java.util.concurrent.ExecutorService; 27 | import java.util.concurrent.Executors; 28 | import java.util.concurrent.Future; 29 | 30 | /** Command-line entry point of the hyper-parameter tuner. */ public class Laozhongyi { 31 | /** Parses the options, optionally resumes from a checkpoint (-cp), then repeatedly tries every multi-value hyper-parameter until a full pass changes nothing and the current params equal the best ones, or 100000 passes have run. The worker pool is always shut down, even when an exception escapes. */ public static void main(final String[] args) { 32 | final Options options = new Options(); 33 | final Option scope = new Option("s", true, "scope file path"); 34 | scope.setRequired(true); 35 | options.addOption(scope); 36 | 37 | final Option cmdOpt = new Option("c", true, "cmd"); 38 | cmdOpt.setRequired(true); 39 | options.addOption(cmdOpt); 40 | 41 | final Option workingDir = new Option("wd", true, "working directory"); 42 | workingDir.setRequired(false); 43 | options.addOption(workingDir); 44 | 45 | final Option strategyOpt = new Option("strategy", true, "base or sa"); 46 | strategyOpt.setRequired(true); 47 | options.addOption(strategyOpt); 48 | 49 | final Option saTOpt = new Option("sat", true, "simulated annealing initial temperature"); 50 | saTOpt.setRequired(false); 51 | options.addOption(saTOpt); 52 | 53 | final Option saROpt = new Option("sar", true, "simulated annealing ratio"); 54 | saROpt.setRequired(false); 55 | options.addOption(saROpt); 56 | 57 | final Option processCountOpt = new Option("pc", true, "process count upper bound"); 58 | processCountOpt.setRequired(true); 59 | 
options.addOption(processCountOpt); 60 | 61 | final Option checkPointFilePathOpt = new Option("cp", true, "check point file path"); 62 | checkPointFilePathOpt.setRequired(false); 63 | options.addOption(checkPointFilePathOpt); 64 | 65 | final CommandLineParser parser = new DefaultParser(); 66 | final CommandLine cmd; 67 | try { 68 | cmd =
parser.parse(options, args); 69 | } catch (final ParseException e) { 70 | final HelpFormatter formatter = new HelpFormatter(); 71 | formatter.printHelp("laozhongyi", options); 72 | throw new IllegalStateException(e); 73 | } 74 | 75 | final Option[] parsedOptions = cmd.getOptions(); 76 | for (final Option option : parsedOptions) { 77 | System.out.println(option.getOpt() + ":" + option.getValue()); 78 | } 79 | 80 | final String scopeFilePath = cmd.getOptionValue("s"); 81 | final List items = HyperParameterScopeConfigReader 82 | .read(scopeFilePath); 83 | 84 | Strategy strategy; 85 | final String strategyStr = cmd.getOptionValue("strategy"); 86 | if (strategyStr.equals("base")) { 87 | strategy = new BaseStrategy(); 88 | } else if (strategyStr.equals("sa")) { Preconditions.checkArgument(cmd.hasOption("sar") && cmd.hasOption("sat"), "strategy sa requires -sar and -sat"); 89 | final float ratio = Float.parseFloat(cmd.getOptionValue("sar")); 90 | final float t = Float.parseFloat(cmd.getOptionValue("sat")); 91 | strategy = new TraiditionalSimulatedAnnealingStrategy(ratio, t); 92 | } else { 93 | throw new IllegalArgumentException("strategy param is " + strategyStr); 94 | } 95 | 96 | final String workingDirStr = cmd.getOptionValue("wd"); 97 | 98 | final String programCmd = cmd.getOptionValue("c"); 99 | 100 | final int processCountLimit = Integer.parseInt(cmd.getOptionValue("pc")); 101 | 102 | final String checkPointFilePath = cmd.getOptionValue("cp"); 103 | CheckPointManager checkPointManager = null; 104 | CheckPointData initialCheckPointData = null; 105 | if (checkPointFilePath != null) { 106 | checkPointManager = strategyStr.equals("base") ?
107 | new CheckPointManager(checkPointFilePath) : 108 | new SimulatedAnnealingCheckPointManager(checkPointFilePath); 109 | initialCheckPointData = checkPointManager.load(); 110 | } 111 | 112 | if (initialCheckPointData != null && strategyStr.equals("sa")) { 113 | SimulatedAnnealingCheckPointData saCheckPointData = 114 | (SimulatedAnnealingCheckPointData) initialCheckPointData; 115 | strategy = new TraiditionalSimulatedAnnealingStrategy(saCheckPointData.getDecayRate(), 116 | saCheckPointData.getTemperature()); 117 | } 118 | 119 | GeneratedFileManager.mkdirForLog(); 120 | GeneratedFileManager.mkdirForHyperParameterConfig(); 121 | 122 | final MutablePair, Float> bestPair = initialCheckPointData == null ? 123 | MutablePair.of(Collections.emptyMap(), -1.f) : 124 | MutablePair.of(initialCheckPointData.getBestHyperParameters(), 125 | initialCheckPointData.getBestScore()); 126 | 127 | Map params; 128 | if (initialCheckPointData == null) { 129 | final Random random = new Random(); 130 | params = initHyperParameters(items, random); 131 | } else { 132 | params = initialCheckPointData.getCurrentHyperParameters(); 133 | } 134 | final Set multiValueKeys = getMultiValueKeys(items); 135 | final ExecutorService executorService = Executors.newFixedThreadPool(processCountLimit); try { /* shut the pool down in finally so its threads cannot keep the JVM alive after an exception */ 136 | boolean isFirstTry = initialCheckPointData == null; 137 | 138 | int currentHyperParameterIndex = initialCheckPointData == null ?
0 : 139 | initialCheckPointData.getCurrentHyperParametersIndex(); 140 | 141 | if (initialCheckPointData != null) { 142 | for (Entry, Float> entry : 143 | initialCheckPointData.getHyperParametersToScore().entrySet()) { 144 | HyperParamResultManager.putResult(entry.getKey(), entry.getValue()); 145 | } 146 | } 147 | 148 | int count = 0; 149 | while (true) { 150 | if (++count > 100000) { 151 | break; 152 | } 153 | String modifiedKey = StringUtils.EMPTY; 154 | int itemIndex = -1; 155 | for (final HyperParameterScopeItem item : items) { 156 | ++itemIndex; 157 | if (itemIndex < currentHyperParameterIndex && count == 1) { 158 | continue; 159 | } 160 | Preconditions.checkState(!item.getValues().isEmpty()); 161 | if (item.getValues().size() == 1) { 162 | continue; 163 | } 164 | System.out.println("item:" + item); 165 | final Pair, Float> result = tryItem(item, multiValueKeys, 166 | params, programCmd, executorService, strategy, bestPair, workingDirStr, 167 | isFirstTry, checkPointManager); 168 | isFirstTry = false; 169 | System.out.println("key:" + item.getKey() + "\nsuitable value:" + result.getLeft() 170 | + " result:" + result.getRight()); 171 | if (!result.getLeft().equals(params)) { 172 | modifiedKey = item.getKey(); 173 | params = result.getLeft(); 174 | } 175 | System.out.println("suitable params now:"); 176 | for (final Entry entry : params.entrySet()) { 177 | System.out.println(entry.getKey() + ": " + entry.getValue()); 178 | } 179 | System.out.println("best result:" + bestPair.getRight()); 180 | System.out.println("best params now:"); 181 | for (final Entry entry : bestPair.getLeft().entrySet()) { 182 | System.out.println(entry.getKey() + ": " + entry.getValue()); 183 | } 184 | final String logFileFullPath = GeneratedFileManager 185 | .getLogFileFullPath(bestPair.getLeft(), multiValueKeys); 186 | System.out.println("best log path is " + logFileFullPath); 187 | final String hyperPath = GeneratedFileManager 188 |
.getHyperParameterConfigFileFullPath(bestPair.getLeft(), multiValueKeys); 189 | System.out.println("best hyper path is " + hyperPath); 190 | 191 | if (checkPointManager != null) { 192 | final CheckPointData checkPointData; 193 | if (strategyStr.equals("sa")) { 194 | final SimulatedAnnealingStrategy saStrategy = 195 | (SimulatedAnnealingStrategy) strategy; 196 | checkPointData = 197 | new SimulatedAnnealingCheckPointData(saStrategy.getTemperature(), 198 | saStrategy.getDecayRate(), params, itemIndex + 1, 199 | bestPair.getLeft(), bestPair.getRight(), 200 | HyperParamResultManager.deepCopyResults()); 201 | } else { 202 | checkPointData = new CheckPointData(params, itemIndex + 1, 203 | bestPair.getLeft(), bestPair.getRight(), 204 | HyperParamResultManager.deepCopyResults()); 205 | } 206 | checkPointManager.save(checkPointData); 207 | } 208 | } 209 | strategy.iterationEnd(); 210 | 211 | if (modifiedKey.equals(StringUtils.EMPTY)) { 212 | if (params.equals(bestPair.getLeft())) { 213 | break; 214 | } else { 215 | params = bestPair.getLeft(); 216 | } 217 | } 218 | } 219 | 220 | System.out.println("hyperparameter adjusted, the best result is " + bestPair.getRight()); 221 | System.out.println("best hyperparameters:"); 222 | for (final Entry entry : bestPair.getLeft().entrySet()) { 223 | System.out.println(entry.getKey() + ": " + entry.getValue()); 224 | } 225 | } finally { executorService.shutdown(); } 226 | } 227 | 228 | /** Picks a uniformly random value for every item and returns the choices in a TreeMap keyed by hyper-parameter name. */ private static Map initHyperParameters( 229 | final List items, final Random random) { 230 | final Map result = Maps.newTreeMap(); 231 | for (final HyperParameterScopeItem item : items) { 232 | final int index = random.nextInt(item.getValues().size()); 233 | result.put(item.getKey(), item.getValues().get(index)); 234 | } 235 | return result; 236 | } 237 | 238 | /** Evaluates every value of item concurrently on executorService, with all other 
params taken from currentHyperParameter; updates best (and calls strategy.storeBest()) when a better score is seen, and returns the params and score at the index chosen by strategy. */ private static Pair, Float> tryItem(final HyperParameterScopeItem item, 239 | final Set multiValueKeys, 240 | final Map currentHyperParameter, 241 | final String cmdString, 242 | final ExecutorService
executorService, 243 | final Strategy strategy, 244 | final MutablePair 245 | , Float> best, 246 | final String workingDir, 247 | boolean isFirstTry, CheckPointManager 248 | checkPointManager) { 249 | Preconditions.checkArgument(item.getValues().size() > 1); 250 | 251 | final List> futures = Lists.newArrayList(); 252 | final List paramsAndCallables = Lists.newArrayList(); 253 | 254 | for (final String value : item.getValues()) { 255 | final Map copiedHyperParameter = Utils 256 | .modifiedNewMap(currentHyperParameter, item.getKey(), value); 257 | 258 | final ShellProcess callable = new ShellProcess(copiedHyperParameter, multiValueKeys, 259 | cmdString, workingDir, checkPointManager); 260 | final Future future = executorService.submit(callable); 261 | futures.add(future); 262 | paramsAndCallables.add(new ParamsAndCallable(copiedHyperParameter, callable)); 263 | } 264 | 265 | final List results = Lists.newArrayList(); 266 | for (int i = 0; i < futures.size(); ++i) { 267 | final Float futureResult; 268 | try { 269 | futureResult = futures.get(i).get(); 270 | System.out.println("key:" + item.getKey() + " value:" 271 | + paramsAndCallables.get(i).getParams().get(item.getKey()) 272 | + " futureResult:" + futureResult); 273 | } catch (final InterruptedException | ExecutionException e) { 274 | if (e instanceof InterruptedException) { Thread.currentThread().interrupt(); /* restore the interrupt flag for callers */ } throw new IllegalStateException(e); 275 | } 276 | 277 | results.add(futureResult); 278 | } 279 | 280 | float localBestResult = -1; 281 | Map localBestParams = Collections.emptyMap(); 282 | for (int i = 0; i < results.size(); ++i) { 283 | if (results.get(i) > localBestResult) { 284 | localBestResult = results.get(i); 285 | localBestParams = paramsAndCallables.get(i).getParams(); 286 | } 287 | } 288 | Preconditions.checkState(!localBestParams.isEmpty()); 289 | 290 | synchronized (Laozhongyi.class) { 291 | if (localBestResult > best.getRight()) { 292 | 
best.setRight(localBestResult); 293 | best.setLeft(localBestParams); 294 | strategy.storeBest(); 295 | } 296 | } 297 | 298 | final String
originalValue = currentHyperParameter.get(item.getKey()); 299 | final int originalIndex = item.getValues().indexOf(originalValue); 300 | final int suitableIndex = strategy.chooseSuitableIndex(results, originalIndex, isFirstTry); 301 | 302 | return ImmutablePair.of(paramsAndCallables.get(suitableIndex).getParams(), 303 | results.get(suitableIndex)); 304 | } 305 | 306 | /** Returns, in sorted order, the keys of items that have more than one candidate value. */ private static Set getMultiValueKeys(final List items) { 307 | final Set keys = Sets.newTreeSet(); 308 | for (final HyperParameterScopeItem item : items) { 309 | if (item.getValues().size() > 1) { 310 | keys.add(item.getKey()); 311 | } 312 | } 313 | 314 | return keys; 315 | } 316 | } 317 | --------------------------------------------------------------------------------