├── .gitignore ├── edge-algorithm ├── README.md ├── pom.xml └── src │ └── main │ └── java │ └── cn │ └── edu │ └── scut │ ├── agent │ ├── IMAAgent.java │ ├── MABuffer.java │ ├── mappo │ │ ├── MAPPOAgent.java │ │ └── MAPPOConfig.java │ └── masac │ │ ├── AlphaBlock.java │ │ ├── MASACAgent.java │ │ └── MASACConfiguration.java │ ├── config │ └── DJLConfig.java │ ├── service │ └── TransitionService.java │ └── util │ ├── DJLUtils.java │ ├── FileUtils.java │ └── PlotUtils.java ├── edge-api ├── README.md ├── pom.xml └── src │ └── main │ └── java │ └── cn │ └── edu │ └── scut │ ├── bean │ ├── Constants.java │ ├── EdgeNode.java │ ├── Link.java │ ├── RatcVo.java │ ├── StoreConstants.java │ ├── Task.java │ └── TaskStatus.java │ ├── mapper │ ├── EdgeNodeMapper.java │ ├── LinkMapper.java │ └── TaskMapper.java │ ├── service │ ├── EdgeNodeService.java │ ├── LinkService.java │ └── TaskService.java │ └── util │ └── ArrayUtils.java ├── edge-controller ├── Dockerfile ├── README.md ├── pom.xml └── src │ └── main │ ├── java │ └── cn │ │ └── edu │ │ └── scut │ │ ├── ControllerApp.java │ │ ├── config │ │ ├── DataGenerationConfig.java │ │ └── SpringConfig.java │ │ ├── controller │ │ └── EdgeNodeController.java │ │ ├── service │ │ ├── EdgeConfigService.java │ │ └── UserService.java │ │ ├── thread │ │ └── UserRunnable.java │ │ └── utils │ │ └── SpringBeanUtils.java │ └── resources │ ├── META-INF │ └── MANIFEST.MF │ └── bootstrap.yaml ├── edge-experiment ├── Dockerfile ├── README.md ├── pom.xml ├── src │ └── main │ │ ├── java │ │ └── cn │ │ │ └── edu │ │ │ └── scut │ │ │ ├── ExperimentApp.java │ │ │ ├── config │ │ │ ├── HadoopConfig.java │ │ │ └── SpringConfig.java │ │ │ ├── controller │ │ │ └── ModelController.java │ │ │ ├── runner │ │ │ ├── IRunner.java │ │ │ ├── OfflineMARLTrainingRunner.java │ │ │ ├── OnlineHeuristicDataRunner.java │ │ │ ├── OnlineHeuristicTestRunner.java │ │ │ ├── OnlineMARLTestingRunner.java │ │ │ └── OnlineMARLTrainingRunner.java │ │ │ ├── service │ │ │ ├── 
PlotService.java │ │ │ └── RunnerService.java │ │ │ └── util │ │ │ ├── DateTimeUtils.java │ │ │ ├── MathUtils.java │ │ │ └── SpringBeanUtils.java │ │ └── resources │ │ └── bootstrap.yaml └── test-results │ └── buffer │ └── test-buffer.txt ├── edge-node ├── Dockerfile ├── pom.xml └── src │ └── main │ ├── java │ └── cn │ │ └── edu │ │ └── scut │ │ ├── EdgeNodeApp.java │ │ ├── bean │ │ └── EdgeNodeSystem.java │ │ ├── config │ │ ├── Config.java │ │ └── HadoopConfig.java │ │ ├── controller │ │ ├── CmdController.java │ │ ├── EdgeNodeController.java │ │ ├── FileController.java │ │ ├── ModelController.java │ │ └── UserController.java │ │ ├── queue │ │ ├── ExecutionQueue.java │ │ └── TransmissionQueue.java │ │ ├── scheduler │ │ ├── DRLScheduler.java │ │ ├── IScheduler.java │ │ ├── RandomScheduler.java │ │ ├── ReactiveScheduler.java │ │ └── ReliabilityTwoChoice.java │ │ ├── service │ │ └── EdgeNodeSystemService.java │ │ ├── thread │ │ ├── ExecutionRunnable.java │ │ └── TransmissionRunnable.java │ │ └── utils │ │ └── SpringBeanUtils.java │ └── resources │ ├── META-INF │ └── MANIFEST.MF │ └── bootstrap.yaml ├── k8s ├── edge │ ├── edge-controller.yaml │ ├── edge-experiment.yaml │ └── edge-node.yaml ├── mysql.yaml └── nacos.yaml ├── nacos └── DEFAULT_GROUP │ ├── application-edge-computing.yaml │ ├── application-mappo.yaml │ ├── application-masac.yaml │ ├── application-random.yaml │ ├── application-reactive.yaml │ ├── application-reliability-two-choice.yaml │ └── application.yaml ├── pom.xml ├── sql └── create_table.sql ├── start-experiment.sh └── stop-experiment.sh /.gitignore: -------------------------------------------------------------------------------- 1 | results 2 | .idea 3 | edge-algorithm/target 4 | edge-api/target 5 | edge-controller/target 6 | edge-experiment/target 7 | edge-node/target 8 | edge-test/target -------------------------------------------------------------------------------- /edge-algorithm/README.md: 
-------------------------------------------------------------------------------- 1 | DRL algorithm. 2 | -------------------------------------------------------------------------------- /edge-algorithm/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | edge-computing 7 | cn.edu.scut 8 | 1.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | edge-algorithm 13 | 14 | 15 | 17 16 | 17 17 | UTF-8 18 | 19 | 20 | 21 | 22 | cn.edu.scut 23 | edge-api 24 | 1.0-SNAPSHOT 25 | 26 | 27 | org.springframework.boot 28 | spring-boot-starter 29 | 30 | 31 | org.projectlombok 32 | lombok 33 | 34 | 35 | 36 | ai.djl 37 | api 38 | 39 | 40 | ai.djl.pytorch 41 | pytorch-engine 42 | runtime 43 | 44 | 45 | ai.djl.pytorch 46 | pytorch-native-cu116 47 | linux-x86_64 48 | runtime 49 | 50 | 51 | ai.djl.pytorch 52 | pytorch-jni 53 | runtime 54 | 55 | 56 | 57 | 58 | tech.tablesaw 59 | tablesaw-core 60 | 0.43.1 61 | 62 | 63 | tech.tablesaw 64 | tablesaw-jsplot 65 | 0.43.1 66 | 67 | 68 | 69 | 70 | 71 | org.springframework.boot 72 | spring-boot-starter-test 73 | test 74 | 75 | 76 | 77 | org.springframework.boot 78 | spring-boot-starter-logging 79 | 80 | 81 | -------------------------------------------------------------------------------- /edge-algorithm/src/main/java/cn/edu/scut/agent/IMAAgent.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.agent; 2 | 3 | import java.io.InputStream; 4 | /** Common contract of the multi-agent RL schedulers in this module (MAPPOAgent, MASACAgent). */ 5 | public interface IMAAgent { 6 | /** Picks an action index for one agent's local state; the implementations in this module exclude actions whose availAction entry is 0. */ 7 | int selectAction(float[] state, int[] availAction, boolean training); 8 | /** Updates the networks from transitions held in the shared MABuffer. */ 9 | void train(); 10 | /** Saves the network parameters locally under a location derived from flag. */ 11 | void saveModel(String flag); 12 | /** Loads parameters previously written by saveModel(flag). */ 13 | void loadModel(String flag); 14 | /** Saves the network parameters and uploads them to HDFS under a location derived from flag. */ 15 | void saveHdfsModel(String flag); 16 | /** Downloads parameters written by saveHdfsModel(flag) and loads them. */ 17 | void loadHdfsModel(String flag); 18 | /** Loads one network's parameters from a stream; fileName selects which network. NOTE: "Steam" is presumably a typo for "Stream"; the name is kept for API compatibility. */ 19 | void loadSteamModel(InputStream inputStream, String fileName); 20 | } 21 | -------------------------------------------------------------------------------- /edge-algorithm/src/main/java/cn/edu/scut/agent/MABuffer.java: 
-------------------------------------------------------------------------------- 1 | package cn.edu.scut.agent; 2 | 3 | 4 | import ai.djl.ndarray.NDList; 5 | import ai.djl.ndarray.NDManager; 6 | import ai.djl.ndarray.index.NDIndex; 7 | import ai.djl.ndarray.types.DataType; 8 | import ai.djl.ndarray.types.Shape; 9 | import cn.edu.scut.util.FileUtils; 10 | import lombok.Getter; 11 | import lombok.extern.slf4j.Slf4j; 12 | import org.apache.hadoop.fs.FileSystem; 13 | import org.apache.hadoop.fs.Path; 14 | import org.springframework.beans.factory.InitializingBean; 15 | import org.springframework.beans.factory.annotation.Autowired; 16 | import org.springframework.beans.factory.annotation.Value; 17 | import org.springframework.context.annotation.Lazy; 18 | import org.springframework.stereotype.Component; 19 | 20 | import java.io.IOException; 21 | import java.nio.file.Files; 22 | import java.nio.file.Paths; 23 | import java.util.ArrayList; 24 | import java.util.Collections; 25 | import java.util.Random; 26 | /** Fixed-capacity ring buffer of joint (all-agent) transitions. Each slot holds [agentNumber][stateShape] states and nextStates, [agentNumber][1] actions and rewards, and [agentNumber][actionShape] availability masks (see afterPropertiesSet). */ 27 | @Component 28 | @Lazy 29 | @Slf4j 30 | public class MABuffer implements InitializingBean { 31 | 32 | @Value("${rl.state-shape}") 33 | private int stateShape; 34 | 35 | @Value("${edgeComputing.edgeNodeNumber}") 36 | private int agentNumber; 37 | 38 | @Value("${rl.action-shape}") 39 | private int actionShape; 40 | 41 | @Value("${rl.buffer-size}") 42 | private int bufferSize; 43 | 44 | @Value("${rl.batch-size}") 45 | private int batchSize; 46 | 47 | @Value("${hadoop.hdfs.url}") 48 | private String hdfsUrl; 49 | 50 | @Getter 51 | private float[][][] states; 52 | @Getter 53 | private int[][][] actions; 54 | @Getter 55 | private float[][][] rewards; 56 | @Getter 57 | private int[][][] availActions; 58 | @Getter 59 | private float[][][] nextStates; 60 | 61 | private int index = 0; // slot the next insert() writes; always kept in [0, bufferSize) 62 | 63 | private int size = 0; // number of valid slots, capped at bufferSize 64 | 65 | @Autowired 66 | private FileSystem fileSystem; 67 | 68 | @Autowired 69 | private Random bufferRandom; 70 | /** Allocates the zero-filled backing arrays once the @Value properties have been injected. */ 71 | @Override 72 | public void 
afterPropertiesSet() { 73 | log.info("buffer-size {} ", bufferSize); 74 | states = new float[bufferSize][agentNumber][stateShape]; 75 | actions = new int[bufferSize][agentNumber][1]; 76 | availActions = new int[bufferSize][agentNumber][actionShape]; 77 | rewards = new float[bufferSize][agentNumber][1]; 78 | nextStates = new float[bufferSize][agentNumber][stateShape]; 79 | } 80 | /** Stores one joint transition at slot {@code index} and advances the ring; once full, earlier slots are overwritten in ring order. The arrays are stored by reference (not copied), so callers should not mutate them afterwards. */ 81 | public void insert(float[][] state, int[][] action, int[][] availAction, float[][] reward, float[][] nextState) { 82 | states[index] = state; 83 | actions[index] = action; 84 | availActions[index] = availAction; 85 | rewards[index] = reward; 86 | nextStates[index] = nextState; 87 | index = (index + 1) % bufferSize; 88 | size = Math.min(size + 1, bufferSize); 89 | } 90 | 91 | // off-policy: uniform minibatch of batchSize distinct stored transitions (sampled without replacement); throws IllegalStateException while fewer than batchSize are stored 92 | public NDList sample(NDManager manager) { if (size < batchSize) { throw new IllegalStateException("buffer holds " + size + " transitions, fewer than batch-size " + batchSize); } 93 | var list = new ArrayList(); 94 | for (int i = 0; i < size; i++) { 95 | list.add(i); 96 | } 97 | Collections.shuffle(list, bufferRandom); 98 | var batch = new int[batchSize]; 99 | for (int i = 0; i < batchSize; i++) { 100 | batch[i] = list.get(i); 101 | } 102 | 103 | var ndStates = manager.zeros(new Shape(batchSize, agentNumber, stateShape), DataType.FLOAT32); 104 | var ndActions = manager.zeros(new Shape(batchSize, agentNumber, 1), DataType.INT32); 105 | var ndAvailActions = manager.zeros(new Shape(batchSize, agentNumber, actionShape), DataType.INT32); 106 | var ndRewards = manager.zeros(new Shape(batchSize, agentNumber, 1), DataType.FLOAT32); 107 | var 
ndNextStates = manager.zeros(new Shape(batchSize, agentNumber, stateShape), DataType.FLOAT32); 108 | 109 | for (int i = 0; i < batchSize; i++) { 110 | var ndIndex = new NDIndex(i); 111 | ndStates.set(ndIndex, manager.create(states[batch[i]])); 112 | ndActions.set(ndIndex, manager.create(actions[batch[i]])); 113 | ndAvailActions.set(ndIndex, manager.create(availActions[batch[i]])); 114 | ndRewards.set(ndIndex, manager.create(rewards[batch[i]])); 115 | ndNextStates.set(ndIndex, 
manager.create(nextStates[batch[i]])); 116 | } 117 | return new NDList(ndStates, ndActions, ndAvailActions, ndRewards, ndNextStates); 118 | } 119 | 120 | // on-policy: every slot 0..bufferSize-1 in storage order. NOTE(review): unwritten slots are still zeros, and once the ring has wrapped at a non-zero index storage order differs from insertion order; presumably callers insert exactly bufferSize transitions between calls 121 | public NDList sampleAll(NDManager manager) { 122 | var ndStates = manager.zeros(new Shape(bufferSize, agentNumber, stateShape), DataType.FLOAT32); 123 | var ndActions = manager.zeros(new Shape(bufferSize, agentNumber, 1), DataType.INT32); 124 | var ndAvailActions = manager.zeros(new Shape(bufferSize, agentNumber, actionShape), DataType.INT32); 125 | var ndRewards = manager.zeros(new Shape(bufferSize, agentNumber, 1), DataType.FLOAT32); 126 | var ndNextStates = manager.zeros(new Shape(bufferSize, agentNumber, stateShape), DataType.FLOAT32); 127 | for (int i = 0; i < bufferSize; i++) { 128 | var ndIndex = new NDIndex(i); 129 | ndStates.set(ndIndex, manager.create(states[i])); 130 | ndActions.set(ndIndex, manager.create(actions[i])); 131 | ndAvailActions.set(ndIndex, manager.create(availActions[i])); 132 | ndRewards.set(ndIndex, manager.create(rewards[i])); 133 | ndNextStates.set(ndIndex, manager.create(nextStates[i])); 134 | } 135 | return new NDList(ndStates, ndActions, ndAvailActions, ndRewards, ndNextStates); 136 | } 137 | /** Serializes each backing array under results/buffer/{flag}/ and moves it to the same relative path under hdfsUrl (local file deleted, remote file overwritten). IO failures are rethrown as RuntimeException. */ 138 | public void saveHdfs(String flag) { 139 | var path = Paths.get("results", "buffer", flag); 140 | try { 141 | Files.createDirectories(path); 142 | } catch (IOException e) { 143 | throw new RuntimeException(e); 144 | } 145 | String basePath = "results/buffer/" + flag + "/"; 146 | try { 147 | FileUtils.writeObject(states, 
Paths.get("results", "buffer", flag, "states.array")); 148 | fileSystem.copyFromLocalFile(true, true, new Path(basePath + "states.array"), new Path(hdfsUrl + "/" + basePath + "states.array")); 149 | 150 | FileUtils.writeObject(actions, Paths.get("results", "buffer", flag, "actions.array")); 151 | fileSystem.copyFromLocalFile(true, true, new Path(basePath + "actions.array"), new Path(hdfsUrl + "/" + basePath + "actions.array")); 152 | 153 | 
FileUtils.writeObject(availActions, Paths.get("results", "buffer", flag, "availActions.array")); 154 | fileSystem.copyFromLocalFile(true, true, new Path(basePath + "availActions.array"), new Path(hdfsUrl + "/" + basePath + "availActions.array")); 155 | 156 | FileUtils.writeObject(rewards, Paths.get("results", "buffer", flag, "rewards.array")); 157 | fileSystem.copyFromLocalFile(true, true, new Path(basePath + "rewards.array"), new Path(hdfsUrl + "/" + basePath + "rewards.array")); 158 | 159 | FileUtils.writeObject(nextStates, Paths.get("results", "buffer", flag, "nextStates.array")); 160 | fileSystem.copyFromLocalFile(true, true, new Path(basePath + "nextStates.array"), new Path(hdfsUrl + "/" + basePath + "nextStates.array")); 161 | } catch (IOException e) { 162 | throw new RuntimeException(e); 163 | } 164 | } 165 | /** Downloads the five .array files of results/buffer/{bufferPath}/ from hdfsUrl, deserializes them into this buffer and marks it full. Assumes the saved arrays were created with the same buffer-size; TODO confirm. IO failures are rethrown as RuntimeException. */ 166 | public void loadHdfs(String bufferPath) { 167 | var path = Paths.get("results", "buffer", bufferPath); 168 | try { 169 | Files.createDirectories(path); 170 | } catch (IOException e) { 171 | throw new RuntimeException(e); 172 | } 173 | String basePath = "results/buffer/" + bufferPath + "/"; 174 | try { 175 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + basePath + "states.array"), new Path(basePath + "states.array"), false); 176 | states = (float[][][]) FileUtils.readObject(Paths.get("results", "buffer", bufferPath, "states.array")); 177 | 178 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + basePath + "actions.array"), new Path(basePath + "actions.array"), false); 179 | actions = (int[][][]) FileUtils.readObject(Paths.get("results", "buffer", bufferPath, "actions.array")); 180 | 181 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + basePath + 
"availActions.array"), new Path(basePath + "availActions.array"), false); 182 | availActions = (int[][][]) FileUtils.readObject(Paths.get("results", "buffer", bufferPath, "availActions.array")); 183 | 184 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + basePath 
+ "rewards.array"), new Path(basePath + "rewards.array"), false); 185 | rewards = (float[][][]) FileUtils.readObject(Paths.get("results", "buffer", bufferPath, "rewards.array")); 186 | 187 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + basePath + "nextStates.array"), new Path(basePath + "nextStates.array"), false); 188 | nextStates = (float[][][]) FileUtils.readObject(Paths.get("results", "buffer", bufferPath, "nextStates.array")); 189 | } catch (IOException e) { 190 | throw new RuntimeException(e); 191 | } 192 | index = 0; // buffer is full, so the next insert() wraps to slot 0 (index must stay below bufferSize) 193 | size = bufferSize; 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /edge-algorithm/src/main/java/cn/edu/scut/agent/mappo/MAPPOAgent.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.agent.mappo; 2 | 3 | import ai.djl.MalformedModelException; 4 | import ai.djl.Model; 5 | import ai.djl.ndarray.NDArray; 6 | import ai.djl.ndarray.NDArrays; 7 | import ai.djl.ndarray.NDList; 8 | import ai.djl.ndarray.NDManager; 9 | import ai.djl.training.TrainingConfig; 10 | import ai.djl.translate.NoopTranslator; 11 | import ai.djl.translate.TranslateException; 12 | import cn.edu.scut.agent.IMAAgent; 13 | import cn.edu.scut.agent.MABuffer; 14 | import cn.edu.scut.util.DJLUtils; 15 | import cn.edu.scut.util.FileUtils; 16 | import lombok.extern.slf4j.Slf4j; 17 | import org.apache.hadoop.fs.FileSystem; 18 | import org.apache.hadoop.fs.Path; 19 | import org.springframework.beans.factory.InitializingBean; 20 | import org.springframework.beans.factory.annotation.Autowired; 21 | import org.springframework.beans.factory.annotation.Value; 22 | import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; 23 | import org.springframework.stereotype.Component; 24 | 25 | import java.io.IOException; 26 | import java.io.InputStream; 27 | import 
java.nio.file.Paths; 28 | import java.util.Random; 29 | import 
java.util.concurrent.locks.Lock; 30 | import java.util.concurrent.locks.ReentrantLock; 31 | /** MAPPO agent: one actor Model (action logits) and one critic Model (state value), trained with the PPO clipped surrogate objective on the whole MABuffer. selectAction() and train() are serialized through {@code lock}. */ 32 | @Component 33 | @Slf4j 34 | @ConditionalOnProperty(name = "rl.name", havingValue = "mappo") 35 | public class MAPPOAgent implements IMAAgent, InitializingBean { 36 | 37 | @Autowired 38 | private Random schedulerRandom; 39 | 40 | @Autowired 41 | private NDManager manager; 42 | 43 | @Autowired 44 | private MABuffer buffer; 45 | 46 | @Value("${rl.use-normalized-reward}") 47 | private boolean useNormalizedReward; 48 | 49 | @Value("${rl.epoch}") 50 | private int epoch; 51 | 52 | @Value("${rl.clip}") 53 | private float clip; 54 | 55 | @Value("${rl.use-entropy}") 56 | private boolean useEntropy; 57 | 58 | @Value("${rl.entropy-coef}") 59 | private float entropyCoef; 60 | 61 | @Value("${rl.gamma}") 62 | private float gamma; 63 | 64 | @Value("${edgeComputing.episodeLimit}") 65 | private int episodeLimit; 66 | 67 | @Value("${rl.use-gae}") 68 | private boolean useGae; 69 | 70 | @Value("${rl.gae-lambda}") 71 | private float gaeLambda; 72 | 73 | @Value("${hadoop.hdfs.url}") 74 | private String hdfsUrl; 75 | 76 | @Value("${spring.application.name}") 77 | String name; 78 | 79 | @Autowired 80 | private Model actorModel; 81 | 82 | @Autowired 83 | private Model criticModel; 84 | 85 | @Autowired 86 | private FileSystem fileSystem; 87 | 88 | @Autowired 89 | private TrainingConfig trainingConfig; 90 | 91 | private final Lock lock = new ReentrantLock(); // serializes actor/critic use in selectAction() and train() 92 | 93 | @Override 94 | public void afterPropertiesSet() { 95 | } 96 | /** Samples one action from softmax(actor(state)) after forcing the logits of unavailable actions (availAction == 0) to -1e8. The training flag is not used here. @throws IllegalStateException if the actor prediction fails */ 
97 | public int selectAction(float[] state, int[] availAction, boolean training) { 98 | int action; 99 | lock.lock(); 100 | try { 101 | var subManager = manager.newSubManager(); 102 | try (subManager) { 103 | var predictor = actorModel.newPredictor(new NoopTranslator()); 104 | try (predictor) { 105 | NDArray out = null; 106 | try { 107 | out = predictor.predict(new NDList(subManager.create(state))).singletonOrThrow(); 108 | log.debug("{}", out); 109 | } catch 
(TranslateException e) { 110 | log.error("predict error: {}", e.getMessage()); throw new IllegalStateException("predict error", e); 111 | } 112 | var bool = subManager.create(availAction).eq(0); 113 | assert out != null; 114 | out.set(bool, -1e8f); 115 | var prob = out.softmax(-1); 116 | log.debug("prob: {}", prob); 117 | action = DJLUtils.sampleMultinomial(schedulerRandom, prob); 118 | } 119 | } 120 | } finally { 121 | lock.unlock(); 122 | } 123 | return action; 124 | } 125 | /** Runs {@code epoch} PPO epochs over every buffered transition (MABuffer.sampleAll). Advantages come from GAE or from discounted returns minus critic values. Each epoch takes one actor step (clipped surrogate, optional entropy bonus) and one critic step (0.5 * MSE against advantages + old values). */ 126 | public void train() { 127 | lock.lock(); 128 | try { 129 | var subManager = manager.newSubManager(); 130 | try (subManager) { 131 | var list = buffer.sampleAll(subManager); 132 | var states = list.get(0); 133 | var actions = list.get(1); 134 | var availActions = list.get(2); 135 | var rewards = list.get(3); 136 | var nextStates = list.get(4); 137 | // Trick: normalized reward 138 | if (useNormalizedReward) { 139 | var mean = rewards.mean(); 140 | var std = rewards.sub(mean).pow(2).mean().sqrt().add(1e-5f); 141 | rewards = rewards.sub(mean).div(std); 142 | } 143 | var criticTrainer = criticModel.newTrainer(trainingConfig); 144 | var actorTrainer = actorModel.newTrainer(trainingConfig); 145 | try (criticTrainer; actorTrainer) { 146 | var stateValues = criticTrainer.evaluate(new NDList(states)).singletonOrThrow(); 147 | var nextStateValues = criticTrainer.evaluate(new NDList(nextStates)).singletonOrThrow(); 148 | // gae 149 | NDArray advantages; 150 | if (useGae) { 151 | advantages = DJLUtils.getGae(rewards, stateValues, nextStateValues, subManager, gamma, episodeLimit, gaeLambda); 152 | } else { 153 | var returns = DJLUtils.getReturn(rewards, nextStateValues, gamma, episodeLimit); 154 | advantages = 
returns.sub(stateValues); 155 | } 156 | var targets = advantages.add(stateValues); 157 | var out = actorTrainer.evaluate(new NDList(states)).singletonOrThrow(); 158 | var index1 = availActions.eq(0); 159 | out.set(index1, -1e5f); 160 | var logProb = out.logSoftmax(-1); 161 | var oldLogProbTaken = logProb.gather(actions, -1); 162 | for (int i = 0; 
i < epoch; i++) { 163 | var gradientCollector = actorTrainer.newGradientCollector(); 164 | try (gradientCollector) { 165 | var output = actorTrainer.forward(new NDList(states)).singletonOrThrow(); 166 | var index = availActions.eq(0); 167 | output.set(index, -1e5f); 168 | var probabilities = output.softmax(-1); 169 | // attention!!! logSoftmax on the masked logits (instead of probabilities.log()), presumably for numerical stability 170 | var logProbabilities = output.logSoftmax(-1); 171 | var logProbTaken = logProbabilities.gather(actions, -1); 172 | // PPO 173 | var ratios = logProbTaken.sub(oldLogProbTaken).exp(); 174 | // clipped surrogate objective 175 | var surr1 = ratios.mul(advantages); 176 | var surr2 = ratios.clip(1 - clip, 1 + clip).mul(advantages); 177 | // maximizing 178 | var loss = NDArrays.minimum(surr1, surr2).mean().neg(); 179 | // Trick: entropy 180 | if (useEntropy) { 181 | // entropy definition: H = -sum_a p(a) log p(a). NOTE(review): .mean() below also averages over the action axis (H / actionShape); entropyCoef is presumably tuned for that scale 182 | var entropy = probabilities.mul(logProbabilities).neg(); 183 | // maximizing entropy 184 | var entropyLoss = entropy.mean().mul(entropyCoef).neg(); 185 | loss = loss.add(entropyLoss); 186 | } 187 | gradientCollector.backward(loss); 188 | actorTrainer.step(); 189 | } 190 | var criticGradientCollector = criticTrainer.newGradientCollector(); 191 | try (criticGradientCollector) { 192 | var newValues = criticTrainer.forward(new NDList(states)).singletonOrThrow(); 193 | var criticLoss = newValues.sub(targets).pow(2).mean().mul(0.5f); 194 | criticGradientCollector.backward(criticLoss); 195 | criticTrainer.step(); 196 | } 197 | } 198 | } 199 | } 200 | } finally { 201 | lock.unlock(); 202 | } 203 | } 204 | /** Saves actor and critic parameters under results/model/{flag}; IO failures are logged, not thrown. */ 205 | public void saveModel(String flag) { 206 | String basePath = "results/model/" + flag; 207 | try { 208 | 
actorModel.save(Paths.get(basePath), null); 209 | criticModel.save(Paths.get(basePath), null); 210 | } catch (IOException e) { 211 | log.info("save model: {}", e.getMessage()); 212 | } 213 | } 214 | /** Loads actor and critic parameters from results/model/{flag}; failures are logged, not thrown. */ 215 | public void loadModel(String flag) { 216 | String basePath = "results/model/" + flag; 217 | try { 218 | actorModel.load(Paths.get(basePath)); 219 | 
criticModel.load(Paths.get(basePath)); 220 | } catch (IOException | MalformedModelException e) { 221 | log.info("load model: {}", e.getMessage()); 222 | } 223 | } 224 | /** Saves both models locally, replaces results/model/{flag} on HDFS and moves the two .params files there (local copies deleted). Errors are logged, not thrown. */ 225 | public void saveHdfsModel(String flag) { 226 | String basePath = "results/model/" + flag; 227 | try { 228 | actorModel.save(Paths.get(basePath), null); 229 | criticModel.save(Paths.get(basePath), null); 230 | } catch (IOException e) { 231 | log.error("save model error: {}", e.getMessage()); 232 | } 233 | var actorPath = basePath + "/actor-0000.params"; 234 | var criticPath = basePath + "/critic-0000.params"; 235 | try { 236 | if (fileSystem.exists(new Path(hdfsUrl + "/" + basePath))) { 237 | fileSystem.delete(new Path(hdfsUrl + "/" + basePath), true); 238 | } 239 | } catch (IOException e) { 240 | log.error("delete file in file system error: {}", e.getMessage()); 241 | } 242 | try { 243 | fileSystem.copyFromLocalFile(true, true, new Path(actorPath), new Path(hdfsUrl + "/" + actorPath)); 244 | fileSystem.copyFromLocalFile(true, true, new Path(criticPath), new Path(hdfsUrl + "/" + criticPath)); 245 | } catch (IOException e) { 246 | log.error("file system save file error: {}", e.getMessage()); 247 | } 248 | } 249 | /** Copies the actor/critic .params files from HDFS into results/model/{flag}, loads both models, re-enables gradients on all of their parameters, then deletes the local directory. @throws RuntimeException if loading the local files fails */ 250 | public void loadHdfsModel(String flag) { 251 | String basePath = "results/model/" + flag; 252 | try { 253 | var actorPath = basePath + "/actor-0000.params"; 254 | var criticPath = basePath + "/critic-0000.params"; 255 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + actorPath), new Path(actorPath), true); 256 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + 
criticPath), new Path(criticPath), true); 257 | } catch (IOException e) { 258 | log.info("load hdfs model: {}", e.getMessage()); 259 | } 260 | try { 261 | actorModel.load(Paths.get(basePath)); 262 | criticModel.load(Paths.get(basePath)); 263 | } catch (IOException | MalformedModelException e) { 264 | throw new RuntimeException(e); 265 | } 266 | var parameters = actorModel.getBlock().getParameters(); 267 | 
for (String key : parameters.keys()) { 268 | parameters.get(key).getArray().setRequiresGradient(true); 269 | } 270 | var parameters2 = criticModel.getBlock().getParameters(); 271 | for (String key : parameters2.keys()) { 272 | parameters2.get(key).getArray().setRequiresGradient(true); 273 | } 274 | FileUtils.recursiveDelete(Paths.get(basePath).toFile()); 275 | } 276 | /** Loads one model from a stream: "actor.param" selects the actor, "critic.param" the critic, and any other name throws RuntimeException. Load errors are logged, not thrown. */ 277 | public void loadSteamModel(InputStream inputStream, String fileName) { 278 | switch (fileName) { 279 | case "actor.param" -> { 280 | try { 281 | actorModel.load(inputStream); 282 | } catch (IOException | MalformedModelException e) { 283 | log.info("load stream model error: {}", e.getMessage()); 284 | } 285 | } 286 | case "critic.param" -> { 287 | try { 288 | criticModel.load(inputStream); 289 | } catch (IOException | MalformedModelException e) { 290 | log.info("load stream model error: {}", e.getMessage()); 291 | } 292 | } 293 | default -> throw new RuntimeException("fileName error!"); 294 | } 295 | } 296 | } 297 | -------------------------------------------------------------------------------- /edge-algorithm/src/main/java/cn/edu/scut/agent/mappo/MAPPOConfig.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.agent.mappo; 2 | 3 | import ai.djl.Model; 4 | import ai.djl.ndarray.NDArray; 5 | import ai.djl.ndarray.NDList; 6 | import ai.djl.ndarray.NDManager; 7 | import ai.djl.ndarray.types.DataType; 8 | import ai.djl.ndarray.types.Shape; 9 | import ai.djl.nn.Activation; 10 | import ai.djl.nn.Block; 11 | import ai.djl.nn.SequentialBlock; 12 | import ai.djl.nn.core.Linear; 13 | import ai.djl.training.DefaultTrainingConfig; 14 | import ai.djl.training.TrainingConfig; 15 | import ai.djl.training.loss.Loss; 16 | import ai.djl.training.optimizer.Optimizer; 17 | import 
ai.djl.training.tracker.PolynomialDecayTracker; 18 | import ai.djl.training.tracker.Tracker; 19 | import org.springframework.beans.factory.annotation.Value; 20 | import 
org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; 21 | import org.springframework.context.annotation.Bean; 22 | import org.springframework.context.annotation.Configuration; 23 | /** Spring beans for MAPPO (active when rl.name=mappo): actor/critic MLP models, an Adam optimizer with a fixed or polynomially decayed learning rate, and the shared TrainingConfig. NOTE: MAPPOAgent autowires two Model fields named actorModel/criticModel, so these bean method names presumably matter for injection; keep them. */ 24 | @Configuration 25 | @ConditionalOnProperty(name = "rl.name", havingValue = "mappo") 26 | public class MAPPOConfig { 27 | @Value("${rl.epoch}") 28 | private int epoch; 29 | 30 | @Value("${rl.start-learning-rate}") 31 | private float startLearningRate; 32 | 33 | @Value("${rl.end-learning-rate}") 34 | private float endLearningRate; 35 | 36 | @Value("${rl.use-learning-rate-decay}") 37 | private boolean useLearningRateDecay; 38 | 39 | @Value("${rl.learning-rate}") 40 | private float learningRate; 41 | 42 | @Value("${rl.use-clip-grad}") 43 | private boolean useClipGrad; 44 | 45 | @Value("${rl.clip-grad-coef}") 46 | private float clipGradCoef; 47 | 48 | @Value("${rl.hidden-shape}") 49 | private int hiddenShape; 50 | 51 | @Value("${rl.action-shape}") 52 | private int actionShape; 53 | 54 | @Value("${rl.state-shape}") 55 | private int stateShape; 56 | 57 | @Value("${edgeComputing.episodeNumber}") 58 | private int episodeNumber; 59 | 60 | @Value("${rl.batch-size}") 61 | private int batchSize; 62 | 63 | @Value("${edgeComputing.edgeNodeNumber}") 64 | private int agentNumber; 65 | /** Polynomial decay from start- to end-learning-rate over episodeNumber * epoch * 2 steps. NOTE(review): if the optimizer counts updates per parameter, each parameter may see only episodeNumber * epoch updates (one train() per episode assumed), leaving the decay half-finished; confirm. */ 66 | @Bean 67 | public Tracker polynomialDecayTracker() { 68 | int trainNum = episodeNumber * epoch * 2; 
69 | return PolynomialDecayTracker.builder() 70 | .setBaseValue(startLearningRate) 71 | .setEndLearningRate(endLearningRate) 72 | .setDecaySteps(trainNum) // share by actor and critic 73 | .build(); 74 | } 75 | /** Adam optimizer shared by the actor and critic trainers; uses the decay tracker or a fixed rate, with optional gradient clipping. */ 76 | @Bean 77 | public Optimizer optimizer(Tracker polynomialDecayTracker) { 78 | var adam = Optimizer.adam(); 79 | if (useLearningRateDecay) { 80 | adam.optLearningRateTracker(polynomialDecayTracker); 81 | } else { 82 | adam.optLearningRateTracker(Tracker.fixed(learningRate)); 83 | } 84 | if (useClipGrad) { 85 | adam.optClipGrad(clipGradCoef); 86 | } 87 | return adam.build(); 88 | } 89 | /** Builds an MLP with two ReLU hidden layers of hiddenShape units and a linear output of outputDim units, initialized for FLOAT32 input of shape (batchSize, agentNumber, stateShape). Not a bean itself. */ 
90 | public Block createNetwork(NDManager manager, int outputDim) { 91 | var block = new SequentialBlock(); 92 | block.add(Linear.builder().setUnits(hiddenShape).build()); 93 | block.add(Activation::relu); 94 | block.add(Linear.builder().setUnits(hiddenShape).build()); 95 | block.add(Activation::relu); 96 | block.add(Linear.builder().setUnits(outputDim).build()); 97 | block.initialize(manager, DataType.FLOAT32, new Shape(batchSize, agentNumber, stateShape)); 98 | return block; 99 | } 100 | /** Actor network named "actor": outputs one logit per action (actionShape units). */ 101 | @Bean 102 | public Model actorModel(NDManager manager) { 103 | var model = Model.newInstance("actor"); 104 | model.setBlock(createNetwork(manager, actionShape)); 105 | return model; 106 | } 107 | /** Critic network named "critic": outputs a single state value. */ 108 | @Bean 109 | public Model criticModel(NDManager manager) { 110 | var model = Model.newInstance("critic"); 111 | model.setBlock(createNetwork(manager, 1)); 112 | return model; 113 | } 114 | /** Placeholder Loss required by the DefaultTrainingConfig constructor; MAPPOAgent computes its losses manually, and evaluate() returns null, so it must never be used. */ 115 | @Bean 116 | public Loss loss() { 117 | return new Loss("null") { 118 | @Override 119 | public NDArray evaluate(NDList ndList, NDList ndList1) { 120 | return null; 121 | } 122 | }; 123 | } 124 | /** TrainingConfig combining the placeholder loss and the shared optimizer. */ 125 | @Bean 126 | public TrainingConfig trainingConfig(Optimizer optimizer, Loss loss) { 127 | return new DefaultTrainingConfig(loss) 128 | .optOptimizer(optimizer); 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /edge-algorithm/src/main/java/cn/edu/scut/agent/masac/AlphaBlock.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.agent.masac; 2 | 3 | import ai.djl.ndarray.NDList; 4 | import 
ai.djl.ndarray.types.Shape; 5 | import ai.djl.nn.AbstractBlock; 6 | import ai.djl.nn.Parameter; 7 | import ai.djl.training.ParameterStore; 8 | import ai.djl.training.initializer.ConstantInitializer; 9 | import ai.djl.util.PairList; 10 | /** Single-parameter block holding log(alpha), the SAC entropy temperature. forward ignores its inputs and returns the raw log-alpha parameter; callers exponentiate it. */ 11 | public class AlphaBlock extends AbstractBlock { 12 | // stored in log space, presumably so exp(logAlpha) stays positive while it is optimized 13 | private Parameter logAlpha; 14 | /** @param initAlpha initial alpha; must be > 0 because the parameter starts at Math.log(initAlpha) */ 15 | public AlphaBlock(float initAlpha) { 16 | logAlpha = 
addParameter(Parameter.builder() 17 | .setName("alpha") 18 | .setType(Parameter.Type.WEIGHT) 19 | .optShape(new Shape(1)) 20 | .optInitializer(new ConstantInitializer((float) Math.log(initAlpha))) 21 | .optRequiresGrad(true) 22 | .build()); 23 | } 24 | /** Returns the log-alpha parameter array; inputs and training flag are ignored. */ 25 | @Override 26 | protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList params) { 27 | return new NDList(logAlpha.getArray()); 28 | } 29 | /** Always a single output of shape (1), regardless of input shapes. */ 30 | @Override 31 | public Shape[] getOutputShapes(Shape[] inputShapes) { 32 | return new Shape[]{new Shape(1)}; 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /edge-algorithm/src/main/java/cn/edu/scut/agent/masac/MASACAgent.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.agent.masac; 2 | 3 | import ai.djl.MalformedModelException; 4 | import ai.djl.Model; 5 | import ai.djl.ndarray.NDArray; 6 | import ai.djl.ndarray.NDArrays; 7 | import ai.djl.ndarray.NDList; 8 | import ai.djl.ndarray.NDManager; 9 | import ai.djl.training.TrainingConfig; 10 | import ai.djl.translate.NoopTranslator; 11 | import ai.djl.translate.TranslateException; 12 | import cn.edu.scut.agent.IMAAgent; 13 | import cn.edu.scut.agent.MABuffer; 14 | import cn.edu.scut.util.DJLUtils; 15 | import cn.edu.scut.util.FileUtils; 16 | import lombok.extern.slf4j.Slf4j; 17 | import org.apache.hadoop.fs.FileSystem; 18 | import org.apache.hadoop.fs.Path; 19 | import org.springframework.beans.factory.InitializingBean; 20 | import org.springframework.beans.factory.annotation.Autowired; 21 | import org.springframework.beans.factory.annotation.Value; 22 | import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; 23 | import org.springframework.stereotype.Component; 24 | 25 | import java.io.IOException; 26 | 
import java.io.InputStream; 27 | import java.nio.file.Files; 28 | import java.nio.file.Paths; 29 | import java.util.Random; 30 | 31 | 
/**
 * Discrete multi-agent Soft Actor-Critic (MASAC) agent.
 * <p>
 * Components:
 * <ul>
 *   <li>twin Q networks with target copies</li>
 *   <li>an actor</li>
 *   <li>an auxiliary single-output critic</li>
 *   <li>a learnable temperature, stored as log(alpha) in {@code alphaModel}</li>
 * </ul>
 * Only instantiated when {@code rl.name=masac}.
 */
@Component
@Slf4j
@ConditionalOnProperty(name = "rl.name", havingValue = "masac")
public class MASACAgent implements InitializingBean, IMAAgent {

    @Value("${rl.action-shape}")
    private int actionShape;

    // Fixed temperature, used when rl.use-adaptive-alpha is false.
    @Value("${rl.alpha}")
    private float alpha;

    @Value("${rl.gamma}")
    private float gamma;

    @Value("${rl.use-soft-update}")
    private boolean useSoftUpdate;

    @Value("${rl.tau}")
    private float tau;

    @Value("${rl.use-adaptive-alpha}")
    private boolean useAdaptiveAlpha;

    @Value("${rl.use-normalized-reward}")
    private boolean useNormalizedReward;

    @Autowired
    private Random schedulerRandom;

    @Autowired
    private Model q1Model;
    @Autowired
    private Model q2Model;

    @Autowired
    private Model targetQ1Model;
    @Autowired
    private Model targetQ2Model;

    @Autowired
    private Model actorModel;

    @Autowired
    private Model criticModel;

    @Autowired
    private MABuffer buffer;

    @Autowired
    private NDManager manager;

    // Target policy entropy for the adaptive temperature; set in afterPropertiesSet().
    private float targetEntropy;

    @Autowired
    private Model alphaModel;

    @Value("${spring.application.name}")
    String name;

    @Autowired
    private FileSystem fileSystem;

    @Value("${hadoop.hdfs.url}")
    private String hdfsUrl;

    @Value("${rl.use-cql}")
    private boolean useCql;

    @Value("${rl.cql-weight}")
    private float cqlWeight;

    @Value("${rl.use-addition-critic}")
    private boolean useAdditionCritic;

    @Value("${rl.target-entropy-coef}")
    private float targetEntropyCoef;

    @Autowired
    private TrainingConfig trainingConfig;

    /**
     * Computes the target entropy and starts each target Q network as an exact copy of its
     * online network. The beans are otherwise created with independent random weights.
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        // -log(1/actionShape) = log(actionShape) is the entropy of a uniform policy; the target is a
        // configurable fraction of it. The commonly recommended 0.98 was observed to make alpha
        // explode (an earlier note recommends 0.6).
        targetEntropy = -(float) Math.log(1.0 / actionShape) * targetEntropyCoef;
        DJLUtils.hardUpdate(q1Model.getBlock(), targetQ1Model.getBlock());
        DJLUtils.hardUpdate(q2Model.getBlock(), targetQ2Model.getBlock());
    }

    /**
     * Samples an action from the actor's softmax policy, masking unavailable actions.
     *
     * @param state       observation vector
     * @param availAction mask; entries equal to 0 mark unavailable actions
     * @param training    not used: the action is always sampled
     * @return the sampled action index, or 0 if prediction failed
     */
    public int selectAction(float[] state, int[] availAction, boolean training) {
        int action = 0;
        try (var subManager = manager.newSubManager();
             var predictor = actorModel.newPredictor(new NoopTranslator())) {
            var logits = predictor.predict(new NDList(subManager.create(state))).singletonOrThrow();
            // A -1e8 logit makes the softmax probability of unavailable actions effectively zero.
            var unavailable = subManager.create(availAction).eq(0);
            logits.set(unavailable, -1e8f);
            var prob = logits.softmax(-1);
            action = DJLUtils.sampleMultinomial(schedulerRandom, prob);
        } catch (TranslateException e) {
            log.error("predict error: {}", e.getMessage());
        }
        return action;
    }

    /**
     * Runs one MASAC update on a batch sampled from the buffer. Order of steps:
     * <ol>
     *   <li>temperature step (optional)</li>
     *   <li>actor step</li>
     *   <li>twin Q steps (optionally CQL-regularized)</li>
     *   <li>auxiliary critic step (optional)</li>
     *   <li>soft update of each target Q network from its own online network</li>
     * </ol>
     */
    public void train() {
        try (var subManager = manager.newSubManager()) {
            NDList batch = buffer.sample(subManager);
            var states = batch.get(0);
            var actions = batch.get(1);
            // batch.get(2) holds the available-action masks; they are not used by this update.
            var rewards = batch.get(3);
            var nextStates = batch.get(4);

            if (useNormalizedReward) {
                var mean = rewards.mean();
                var std = rewards.sub(mean).pow(2).mean().sqrt().add(1e-5f);
                rewards = rewards.sub(mean).div(std);
            }

            // Temperature for this step, created in subManager so it is freed with the batch instead
            // of accumulating on the parameter's long-lived manager. It is a plain constant, so the
            // actor/Q losses do not push gradients into log(alpha).
            NDArray alphaValue;
            if (useAdaptiveAlpha) {
                float logAlpha = alphaModel.getBlock().getParameters().get("alpha").getArray().toFloatArray()[0];
                alphaValue = subManager.create(new float[]{(float) Math.exp(logAlpha)});
            } else {
                alphaValue = subManager.create(alpha);
            }

            try (var q1Trainer = q1Model.newTrainer(trainingConfig);
                 var q2Trainer = q2Model.newTrainer(trainingConfig);
                 var actorTrainer = actorModel.newTrainer(trainingConfig);
                 var criticTrainer = criticModel.newTrainer(trainingConfig);
                 var targetQ1Predictor = targetQ1Model.newPredictor(new NoopTranslator());
                 var targetQ2Predictor = targetQ2Model.newPredictor(new NoopTranslator());
                 var alphaTrainer = alphaModel.newTrainer(trainingConfig)) {
                // Soft next-state value: sum_a pi(a|s') * (min(Q1', Q2')(s', a) - alpha * log pi(a|s')).
                var nextLogProbabilities = actorTrainer.evaluate(new NDList(nextStates)).singletonOrThrow().logSoftmax(-1);
                NDArray nextTargetQ1;
                NDArray nextTargetQ2;
                try {
                    nextTargetQ1 = targetQ1Predictor.predict(new NDList(nextStates)).singletonOrThrow();
                    nextTargetQ2 = targetQ2Predictor.predict(new NDList(nextStates)).singletonOrThrow();
                } catch (TranslateException e) {
                    throw new RuntimeException(e);
                }
                var targetQ = nextLogProbabilities.exp()
                        .mul(NDArrays.minimum(nextTargetQ1, nextTargetQ2).sub(nextLogProbabilities.mul(alphaValue)));
                var nextStateValue = targetQ.sum(new int[]{-1}, true);
                // No terminal mask: every transition bootstraps from the next state.
                var target = rewards.add(nextStateValue.mul(gamma));

                // Evaluated once, before any parameter update in this step.
                var q1Value = q1Trainer.evaluate(new NDList(states)).singletonOrThrow();
                var q2Value = q2Trainer.evaluate(new NDList(states)).singletonOrThrow();
                var logProbabilityValue = actorTrainer.evaluate(new NDList(states)).singletonOrThrow().logSoftmax(-1);

                if (useAdaptiveAlpha) {
                    updateAlpha(alphaTrainer, logProbabilityValue);
                }
                updateActor(actorTrainer, states, NDArrays.minimum(q1Value, q2Value), alphaValue);
                updateQNetwork(q1Trainer, states, actions, target);
                updateQNetwork(q2Trainer, states, actions, target);
                if (useAdditionCritic) {
                    updateCritic(criticTrainer, states, target);
                }
            }
            if (useSoftUpdate) {
                // Each target tracks its own online network.
                DJLUtils.softUpdate(q1Model.getBlock(), targetQ1Model.getBlock(), tau);
                DJLUtils.softUpdate(q2Model.getBlock(), targetQ2Model.getBlock(), tau);
            }
            // NOTE(review): with rl.use-soft-update=false the targets are never refreshed in this method;
            // confirm a hard update happens elsewhere.
        }
    }

    /** One gradient step on log(alpha) with loss log(alpha) * (H(pi) - targetEntropy). */
    private void updateAlpha(ai.djl.training.Trainer alphaTrainer, NDArray logProbabilities) {
        try (var collector = alphaTrainer.newGradientCollector()) {
            var entropy = logProbabilities.exp().mul(logProbabilities).sum(new int[]{-1}).mean().neg();
            var logAlpha = alphaModel.getBlock().getParameters().get("alpha").getArray();
            // The per-step array is the receiver, so (as DJL attaches an op's result to the receiver's
            // manager) the loss is not left on the parameter's long-lived manager.
            var loss = entropy.sub(targetEntropy).mul(logAlpha);
            collector.backward(loss);
            alphaTrainer.step();
        }
    }

    /** One policy step minimizing -E_s[sum_a pi(a|s) * (qMin(s,a) - alpha * log pi(a|s))]. */
    private void updateActor(ai.djl.training.Trainer actorTrainer, NDArray states, NDArray qMin, NDArray alphaValue) {
        try (var collector = actorTrainer.newGradientCollector()) {
            var logProbabilities = actorTrainer.forward(new NDList(states)).singletonOrThrow().logSoftmax(-1);
            var loss = logProbabilities.exp()
                    .mul(qMin.sub(logProbabilities.mul(alphaValue)))
                    .sum(new int[]{-1}).mean().neg();
            collector.backward(loss);
            actorTrainer.step();
        }
    }

    /**
     * One squared-TD step for a Q network toward {@code target}. When rl.use-cql is set, the step
     * adds the weighted CQL(H) penalty mean_s[logsumexp_a Q(s,a)] - mean[Q(s, a_data)].
     */
    private void updateQNetwork(ai.djl.training.Trainer qTrainer, NDArray states, NDArray actions, NDArray target) {
        try (var collector = qTrainer.newGradientCollector()) {
            var q = qTrainer.forward(new NDList(states)).singletonOrThrow();
            var qAction = q.gather(actions, -1);
            var loss = qAction.sub(target).pow(2).mean();
            if (useCql) {
                var cqlLoss = logSumExp(q).mean().sub(qAction.mean());
                // NDArray.add returns a new array: reassign, or the penalty is silently dropped.
                loss = loss.add(cqlLoss.mul(cqlWeight));
            }
            collector.backward(loss);
            qTrainer.step();
        }
    }

    /** Regresses the single-output auxiliary critic onto the soft TD target. */
    private void updateCritic(ai.djl.training.Trainer criticTrainer, NDArray states, NDArray target) {
        try (var collector = criticTrainer.newGradientCollector()) {
            var value = criticTrainer.evaluate(new NDList(states)).singletonOrThrow();
            var loss = value.sub(target).pow(2).mean();
            collector.backward(loss);
            criticTrainer.step();
        }
    }

    /** Numerically stable log(sum(exp(x))) over the last axis; that axis is kept with size 1. */
    private static NDArray logSumExp(NDArray values) {
        var max = values.max(new int[]{-1}, true);
        return values.sub(max).exp().sum(new int[]{-1}, true).log().add(max);
    }

    /** Writes actor and critic parameters to results/model/{path}/actor.param and critic.param. */
    public void saveModel(String path) {
        String basePath = "results/model/" + path + "/";
        var actorPath = basePath + "actor.param";
        var criticPath = basePath + "critic.param";
        try {
            Files.createDirectories(Paths.get(actorPath).getParent());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        DJLUtils.saveModel(Paths.get(actorPath), actorModel.getBlock());
        DJLUtils.saveModel(Paths.get(criticPath), criticModel.getBlock());
    }

    /** Reads actor and critic parameters written by {@link #saveModel(String)}. */
    public void loadModel(String path) {
        String basePath = "results/model/" + path + "/";
        var actorPath = basePath + "actor.param";
        var criticPath = basePath + "critic.param";
        DJLUtils.loadModel(Paths.get(actorPath), actorModel.getBlock(), manager);
        DJLUtils.loadModel(Paths.get(criticPath), criticModel.getBlock(), manager);
    }

    /**
     * Saves actor and critic locally with Model.save, then moves the two parameter files to HDFS.
     * Used by edge-experiment. Any existing HDFS directory for the flag is deleted first.
     * Errors are logged, not thrown.
     */
    @Override
    public void saveHdfsModel(String flag) {
        String basePath = "results/model/" + flag;
        try {
            actorModel.save(Paths.get(basePath), null);
            criticModel.save(Paths.get(basePath), null);
        } catch (IOException e) {
            log.error("save model error: {}", e.getMessage());
        }
        var actorPath = basePath + "/actor-0000.params";
        var criticPath = basePath + "/critic-0000.params";
        try {
            if (fileSystem.exists(new Path(hdfsUrl + "/" + basePath))) {
                fileSystem.delete(new Path(hdfsUrl + "/" + basePath), true);
            }
        } catch (IOException e) {
            log.error("delete file in file system error: {}", e.getMessage());
        }
        try {
            // delSrc=true: the local copies are removed after upload.
            fileSystem.copyFromLocalFile(true, true, new Path(actorPath), new Path(hdfsUrl + "/" + actorPath));
            fileSystem.copyFromLocalFile(true, true, new Path(criticPath), new Path(hdfsUrl + "/" + criticPath));
        } catch (IOException e) {
            log.error("file system save file error: {}", e.getMessage());
        }
    }

    /**
     * Downloads the actor and critic parameter files written by {@link #saveHdfsModel(String)},
     * loads them, then deletes the local copies. Used by edge-node and edge-experiment.
     */
    @Override
    public void loadHdfsModel(String flag) {
        String basePath = "results/model/" + flag;
        try {
            var actorPath = basePath + "/actor-0000.params";
            var criticPath = basePath + "/critic-0000.params";
            fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + actorPath), new Path(actorPath), true);
            fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + criticPath), new Path(criticPath), true);
        } catch (IOException e) {
            log.info("load hdfs model: {}", e.getMessage());
        }
        try {
            actorModel.load(Paths.get(basePath));
            // The downloaded critic-0000.params belongs to criticModel (it was written by criticModel.save).
            criticModel.load(Paths.get(basePath));
        } catch (IOException | MalformedModelException e) {
            throw new RuntimeException(e);
        }
        FileUtils.recursiveDelete(Paths.get(basePath).toFile());
    }

    /**
     * Loads parameters from a stream into the actor or the critic, chosen by file name.
     *
     * @throws RuntimeException if fileName is neither "actor.param" nor "critic.param"
     */
    @Override
    public void loadSteamModel(InputStream inputStream, String fileName) {
        switch (fileName) {
            case "actor.param" -> {
                try {
                    actorModel.load(inputStream);
                } catch (IOException | MalformedModelException e) {
                    log.info("load stream model error: {}", e.getMessage());
                }
            }
            case "critic.param" -> {
                try {
                    criticModel.load(inputStream);
                } catch (IOException | MalformedModelException e) {
                    log.info("load stream model error: {}", e.getMessage());
                }
            }
            default -> throw new RuntimeException("fileName error!");
        }
    }
}
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/agent/masac/MASACConfiguration.java:
--------------------------------------------------------------------------------
package cn.edu.scut.agent.masac;

import ai.djl.Model;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.TrainingConfig;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/** DJL model and training beans for the MASAC agent; active only when {@code rl.name=masac}. */
@Configuration
@Slf4j
@ConditionalOnProperty(name = "rl.name", havingValue = "masac")
public class MASACConfiguration {
    @Value("${rl.state-shape}")
    private int stateShape;

    @Value("${rl.hidden-shape}")
    private int hiddenShape;

    @Value("${rl.action-shape}")
    private int actionShape;

    @Value("${rl.learning-rate}")
    private float learningRate;

    @Value("${rl.alpha}")
    private float alpha;

    /**
     * Builds an MLP: two hidden Linear + ReLU layers of width hiddenShape, then a Linear layer
     * of width outputDim. It is initialized for FLOAT32 input of shape (stateShape).
     */
    public Block createNetwork(NDManager manager, int outputDim) {
        var block = new SequentialBlock();
        block.add(Linear.builder().setUnits(hiddenShape).build());
        block.add(Activation::relu);
        block.add(Linear.builder().setUnits(hiddenShape).build());
        block.add(Activation::relu);
        block.add(Linear.builder().setUnits(outputDim).build());
        block.initialize(manager, DataType.FLOAT32, new Shape(stateShape));
        return block;
    }

    /** Adam with a fixed learning rate, shared by every trainer built from trainingConfig. */
    @Bean
    public Optimizer optimizer() {
        return Optimizer.adam().optLearningRateTracker(Tracker.fixed(learningRate)).build();
    }

    @Bean
    public Model q1Model(NDManager manager) {
        var model = Model.newInstance("q1");
        model.setBlock(createNetwork(manager, actionShape));
        return model;
    }

    @Bean
    public Model q2Model(NDManager manager) {
        var model = Model.newInstance("q2");
        model.setBlock(createNetwork(manager, actionShape));
        return model;
    }

    @Bean
    public Model targetQ1Model(NDManager manager) {
        var model = Model.newInstance("targetQ1");
        model.setBlock(createNetwork(manager, actionShape));
        return model;
    }

    @Bean
    public Model targetQ2Model(NDManager manager) {
        var model = Model.newInstance("targetQ2");
        model.setBlock(createNetwork(manager, actionShape));
        return model;
    }

    /** Single-output critic. */
    @Bean
    public Model criticModel(NDManager manager) {
        var model = Model.newInstance("critic");
        model.setBlock(createNetwork(manager, 1));
        return model;
    }

    @Bean
    public Model actorModel(NDManager manager) {
        var model = Model.newInstance("actor");
        model.setBlock(createNetwork(manager, actionShape));
        return model;
    }

    /** Learnable temperature: an AlphaBlock holding log(rl.alpha) as its only parameter. */
    @Bean
    public Model alphaModel(NDManager manager) {
        var model = Model.newInstance("alpha");
        var block = new AlphaBlock(alpha);
        block.initialize(manager, DataType.FLOAT32, new Shape(1));
        model.setBlock(block);
        return model;
    }

    /** Placeholder loss returning null; MASACAgent computes and backpropagates its losses itself. */
    @Bean
    public Loss loss() {
        return new Loss("null") {
            @Override
            public NDArray evaluate(NDList ndList, NDList ndList1) {
                return null;
            }
        };
    }

    @Bean
    public TrainingConfig trainingConfig(Optimizer optimizer, Loss loss) {
        return new DefaultTrainingConfig(loss)
                .optOptimizer(optimizer);
    }
}
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/config/DJLConfig.java:
--------------------------------------------------------------------------------
package cn.edu.scut.config;

import ai.djl.engine.Engine;
import ai.djl.ndarray.NDManager;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.Random;

/**
 * Shared DJL infrastructure. Provides the default engine seeded with edgeComputing.seed, a base
 * NDManager from that engine, and a Random with the same seed exposed as bean "bufferRandom".
 */
@Configuration
public class DJLConfig {

    @Value("${edgeComputing.seed}")
    Integer seed;

    @Value("${spring.application.name}")
    String name;

    @Bean
    public Engine engine() {
        Engine engine = Engine.getInstance();
        engine.setRandomSeed(seed);
        return engine;
    }

    @Bean
    public NDManager manager(Engine engine) {
        return engine.newBaseManager();
    }

    @Bean
    public Random bufferRandom() {
        var random = new Random();
        random.setSeed(seed);
        return random;
    }
}
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/service/TransitionService.java:
--------------------------------------------------------------------------------
package cn.edu.scut.service;

import cn.edu.scut.bean.*;
import cn.edu.scut.util.ArrayUtils;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.Arrays;

/** Builds RL transitions (state, action, available actions, reward) from persisted tasks. */
@Service
@Slf4j
public class TransitionService {

    @Value("${edgeComputing.maxTaskRate}")
    private double maxTaskRate;

    @Value("${edgeComputing.maxCpuCore}")
    private int maxCpuCore;

    @Value("${edgeComputing.maxTaskSize}")
    private long maxTaskSize;

    @Value("${edgeComputing.maxTaskComplexity}")
    private long maxTaskComplexity;

    @Value("${edgeComputing.maxTransmissionRate}")
    private double maxTransmissionRate;

    @Value("${edgeComputing.maxTransmissionFailureRate}")
    private double maxTransmissionFailureRate;

    @Value("${edgeComputing.maxExecutionFailureRate}")
    private double maxExecutionFailureRate;

    @Value("${edgeComputing.deadline}")
    private int deadline;

    @Value("${edgeComputing.edgeNodeNumber}")
    private int agentNumber;

    @Autowired
    private EdgeNodeService edgeNodeService;

    @Autowired
    private TaskService taskService;

    @Autowired
    private LinkService linkService;

    /**
     * Builds the normalized observation for a task. Layout:
     * <ul>
     *   <li>one-hot source node (N entries)</li>
     *   <li>task size and complexity (2 entries)</li>
     *   <li>transmission rate of each link from the source (one entry per link row)</li>
     *   <li>cpu count, execution failure rate and task rate per edge node (3 entries per node)</li>
     * </ul>
     */
    public float[] getState(Long taskId) {
        var task = taskService.getById(taskId);
        var edgeNodeInfos = edgeNodeService.list();
        var linkInfos = linkService.list(new QueryWrapper<Link>().eq("source", task.getSource()));

        var obsList = new ArrayList<Float>();
        // Assumes node names shaped like "<a>-<b>-<index>" with a 1-based index; TODO confirm naming.
        int id = Integer.parseInt(task.getSource().split("-")[2]) - 1;
        for (int j = 0; j < agentNumber; j++) {
            obsList.add(j == id ? 1.0f : 0.0f);
        }
        obsList.add(Float.valueOf(task.getTaskSize()) / (float) (maxTaskSize * StoreConstants.Byte.value * StoreConstants.Kilo.value));
        obsList.add(task.getTaskComplexity() / (float) maxTaskComplexity);
        // NOTE(review): row order comes from the DB query (no ORDER BY); confirm it is stable.
        for (Link link : linkInfos) {
            obsList.add((float) (link.getTransmissionRate() / (maxTransmissionRate * Constants.Mega.value * Constants.Byte.value)));
        }
        for (EdgeNode edgeNode : edgeNodeInfos) {
            // cpuNum is a Long and maxCpuCore an int: divide in floating point, otherwise integer
            // division truncates the feature to 0 or 1.
            obsList.add(edgeNode.getCpuNum() / (float) maxCpuCore);
            obsList.add((float) (edgeNode.getExecutionFailureRate() / maxExecutionFailureRate));
            obsList.add((float) (edgeNode.getTaskRate() / maxTaskRate));
        }
        return ArrayUtils.toFloatArray(obsList);
    }

    /**
     * Returns agentNumber (the "no task" action) for EMPTY tasks. Otherwise returns the 0-based
     * index parsed from the destination name.
     */
    public int getAction(Long taskId) {
        var task = taskService.getById(taskId);
        int action;
        if (task.getStatus().equals(TaskStatus.EMPTY)) {
            action = agentNumber;
        } else {
            action = Integer.parseInt(task.getDestination().split("-")[2]) - 1;
        }
        return action;
    }

    /** Parses the comma-separated availability mask stored on the task. */
    public int[] getAvailAction(Long taskId) {
        var task = taskService.getById(taskId);
        return Arrays.stream(task.getAvailAction().split(",")).mapToInt(Integer::parseInt).toArray();
    }

    /**
     * Rewards by final task status: SUCCESS gives +1; execution failure, transmission failure and
     * DROP give -1; EMPTY gives 0.
     *
     * @throws RuntimeException for any other status (e.g. NEW)
     */
    public float getReward(Long taskId) {
        var task = taskService.getById(taskId);
        return switch (task.getStatus()) {
            case SUCCESS -> 1.0f;
            case EXECUTION_FAILURE, TRANSMISSION_FAILURE, DROP -> -1.0f;
            case EMPTY -> 0.0f;
            default -> throw new RuntimeException("task status error");
        };
    }
}
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/util/DJLUtils.java:
--------------------------------------------------------------------------------
package cn.edu.scut.util;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.nn.Block;
import ai.djl.nn.ParameterList;
import lombok.extern.slf4j.Slf4j;

import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Random;

/** DJL helpers for sampling, target-network updates, parameter I/O and return/advantage estimation. */
@Slf4j
public class DJLUtils {

    // Largest leftover probability mass attributed to float rounding rather than an invalid distribution.
    private static final float MULTINOMIAL_TOLERANCE = 1e-5f;

    /**
     * Draws one index from a categorical distribution.
     *
     * @param random source of randomness (one nextFloat() per call)
     * @param prob   probabilities, read in flattened order; expected to sum to ~1
     * @return the sampled index
     * @throws IllegalArgumentException if the probabilities sum to clearly less than the drawn value
     */
    public static int sampleMultinomial(Random random, NDArray prob) {
        float[] probabilities = prob.toFloatArray();
        float rnd = random.nextFloat();
        int lastPositive = -1;
        for (int i = 0; i < probabilities.length; i++) {
            float cut = probabilities[i];
            if (rnd <= cut) {
                return i;
            }
            if (cut > 0) {
                lastPositive = i;
            }
            rnd -= cut;
        }
        // Float32 probabilities (e.g. from softmax) can sum to slightly below the drawn value.
        // Attribute that sliver to the last outcome with non-zero probability instead of failing.
        if (lastPositive >= 0 && rnd <= MULTINOMIAL_TOLERANCE) {
            return lastPositive;
        }
        throw new IllegalArgumentException("Invalid multinomial distribution");
    }

    /**
     * Polyak update: target = (1 - tau) * target + tau * source, applied to every parameter key.
     * The blend is computed on host float arrays so no temporary NDArrays are left on the
     * parameters' long-lived manager.
     */
    public static void softUpdate(Block network, Block targetNetwork, float tau) {
        var parameters = network.getParameters();
        var targetParameters = targetNetwork.getParameters();
        for (String key : parameters.keys()) {
            float[] source = parameters.get(key).getArray().toFloatArray();
            NDArray targetArray = targetParameters.get(key).getArray();
            float[] mixed = targetArray.toFloatArray();
            for (int i = 0; i < mixed.length; i++) {
                mixed[i] = mixed[i] * (1.0f - tau) + source[i] * tau;
            }
            targetArray.set(mixed);
        }
    }

    /** Copies every parameter of network into the same-keyed parameter of targetNetwork. */
    public static void hardUpdate(Block network, Block targetNetwork) {
        var parameters = network.getParameters();
        var targetParameters = targetNetwork.getParameters();
        for (String key : parameters.keys()) {
            targetParameters.get(key).getArray().set(parameters.get(key).getArray().toFloatArray());
        }
    }

    /** Writes a block's parameters to path, creating parent directories as needed. */
    public static void saveModel(Path path, Block actorBlock) {
        createParentDirectories(path);
        try (var fileOutputStream = new FileOutputStream(path.toFile());
             var dataOutputStream = new DataOutputStream(fileOutputStream)) {
            actorBlock.saveParameters(dataOutputStream);
        } catch (IOException e) {
            log.error("save model error");
            throw new RuntimeException(e);
        }
    }

    /** Reads parameters written by {@link #saveModel} into a block and marks them trainable. */
    public static void loadModel(Path path, Block actorBlock, NDManager manager) {
        try (var fileInputStream = new FileInputStream(path.toFile());
             var dataInputStream = new DataInputStream(fileInputStream)) {
            actorBlock.loadParameters(manager, dataInputStream);
            requireGradients(actorBlock);
        } catch (IOException e) {
            log.error("read file error");
            throw new RuntimeException(e);
        } catch (MalformedModelException e) {
            log.error("load model error");
            throw new RuntimeException(e);
        }
    }

    /** Reads parameters from a stream into a block and marks them trainable; closes the stream. */
    public static void loadStreamModel(InputStream inputStream, Block actorBlock, NDManager manager) {
        try (var dataInputStream = new DataInputStream(inputStream)) {
            log.info("load stream model.");
            actorBlock.loadParameters(manager, dataInputStream);
            requireGradients(actorBlock);
        } catch (IOException e) {
            log.error("IOException");
            throw new RuntimeException(e);
        } catch (MalformedModelException e) {
            log.error("MalformedModelException");
            throw new RuntimeException(e);
        }
    }

    // Marks every parameter array of the block as requiring gradients (done after loading).
    private static void requireGradients(Block block) {
        ParameterList parameters = block.getParameters();
        List<String> keys = parameters.keys();
        for (String key : keys) {
            parameters.get(key).getArray().setRequiresGradient(true);
        }
    }

    // Creates the parent directory of path, if it has one.
    private static void createParentDirectories(Path path) {
        Path parent = path.getParent();
        if (parent == null) {
            return;
        }
        try {
            Files.createDirectories(parent);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Discounted returns over the first episodeLimit steps (edge computing: truncated episodes,
     * no done flag). The final step bootstraps from nextStateValues at that step.
     */
    public static NDArray getReturn(NDArray rewards, NDArray nextStateValues, float gamma, int episodeLimit) {
        var res = rewards.zerosLike();
        for (long i = episodeLimit - 1; i >= 0; i--) {
            NDArray nextReturn;
            if (i == episodeLimit - 1) {
                nextReturn = nextStateValues.get(episodeLimit - 1);
            } else {
                nextReturn = res.get(i + 1);
            }
            var currentReturn = rewards.get(i).add(nextReturn.mul(gamma));
            res.set(new NDIndex(i), currentReturn);
        }
        return res;
    }

    /**
     * Discounted returns over the first episodeLimit steps without a value function.
     * NOTE(review): the final step bootstraps from its own reward (r_T + gamma * r_T);
     * confirm this proxy is intended.
     */
    public static NDArray getReturn(NDArray rewards, float gamma, int episodeLimit) {
        var res = rewards.zerosLike();
        for (long i = episodeLimit - 1; i >= 0; i--) {
            NDArray nextReturn;
            if (i == episodeLimit - 1) {
                nextReturn = rewards.get(i);
            } else {
                nextReturn = res.get(i + 1);
            }
            var currentReturn = rewards.get(i).add(nextReturn.mul(gamma));
            res.set(new NDIndex(i), currentReturn);
        }
        return res;
    }

    /**
     * Generalized advantage estimation over the first episodeLimit steps, with no done mask:
     * delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), and A_t = delta_t + gamma * lambda * A_{t+1}.
     */
    public static NDArray getGae(NDArray rewards, NDArray stateValues, NDArray nextStateValues, NDManager manager, float gamma, int episodeLimit, float gaeLambda) {
        var advantages = rewards.zerosLike();
        var deltas = rewards.add(nextStateValues.mul(gamma)).sub(stateValues);
        var nextAdvantage = manager.create(0.0f);
        for (long i = episodeLimit - 1; i >= 0; i--) {
            var advantage = deltas.get(i).add(nextAdvantage.mul(gamma).mul(gaeLambda));
            advantages.set(new NDIndex(i), advantage);
            nextAdvantage = advantage;
        }
        return advantages;
    }
}
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/util/FileUtils.java:
--------------------------------------------------------------------------------
package cn.edu.scut.util;

import lombok.extern.slf4j.Slf4j;

import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;

/** File-system helpers: recursive deletion and Java-serialization read/write. */
@Slf4j
public class FileUtils {

    /** Deletes a file, or a directory and everything beneath it. Missing paths are ignored. */
    public static void recursiveDelete(File file) {
        if (!file.exists()) {
            return;
        }
        if (file.isDirectory()) {
            // listFiles() returns null on an I/O error; then there is nothing to recurse into.
            File[] children = file.listFiles();
            if (children != null) {
                for (File child : children) {
                    // Recursive call.
                    recursiveDelete(child);
                }
            }
        }
        if (file.delete()) {
            log.info("delete file/folder: {}", file.getAbsolutePath());
        } else {
            log.warn("failed to delete file/folder: {}", file.getAbsolutePath());
        }
    }

    /** Serializes data to path with Java serialization, creating parent directories as needed. */
    public static void writeObject(Object data, Path path) {
        Path parent = path.getParent();
        if (parent != null) {
            try {
                Files.createDirectories(parent);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        try (FileOutputStream f = new FileOutputStream(path.toFile());
             ObjectOutput s = new ObjectOutputStream(f)) {
            s.writeObject(data);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deserializes one object from path.
     * NOTE(review): Java deserialization can execute code from the stream; only read trusted files.
     */
    public static Object readObject(Path path) {
        try (FileInputStream in = new FileInputStream(path.toFile());
             ObjectInputStream s = new ObjectInputStream(in)) {
            return s.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

}
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/util/PlotUtils.java:
--------------------------------------------------------------------------------
package cn.edu.scut.util;

import tech.tablesaw.plotly.components.Axis;
import tech.tablesaw.plotly.components.Figure;
import tech.tablesaw.plotly.components.Layout;
import tech.tablesaw.plotly.traces.ScatterTrace;

/** Plotting helpers built on Tablesaw's Plotly wrapper. */
public class PlotUtils {

    /**
     * Builds a line chart with one trace per row: trace i plots y[i] against x[i] and is labelled
     * traceLabels[i]. The legend is shown.
     *
     * @param x           x values, one row per trace
     * @param y           y values, one row per trace
     * @param traceLabels legend label per trace (at least x.length entries)
     * @param xLabel      x-axis title
     * @param yLabel      y-axis title
     * @return the assembled figure
     */
    public static Figure plot(double[][] x, double[][] y, String[] traceLabels, String xLabel, String yLabel) {
        ScatterTrace[] traces = new ScatterTrace[x.length];
        for (int i = 0; i < traces.length; i++) {
            traces[i] =
                    ScatterTrace.builder(x[i], y[i])
                            .mode(ScatterTrace.Mode.LINE)
                            .name(traceLabels[i])
                            .build();
        }
        Layout layout =
                Layout.builder()
                        .showLegend(true)
                        .xAxis(Axis.builder().title(xLabel).build())
                        .yAxis(Axis.builder().title(yLabel).build())
                        .build();
        return new Figure(layout, traces);
    }
}
--------------------------------------------------------------------------------
/edge-api/README.md:
--------------------------------------------------------------------------------
1 | DAO 2 | utils
--------------------------------------------------------------------------------
/edge-api/pom.xml:
--------------------------------------------------------------------------------
1 | 2 | 5 | 6 | edge-computing 7 | cn.edu.scut 8 | 1.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | edge-api 13 | 14 | 15 | 17 16 | 17 17 | UTF-8 18 | 19 | 20 | 21 | 22 | com.alibaba.cloud 23 | spring-cloud-starter-alibaba-nacos-discovery 24 | 25 | 26 | 27 | mysql 28 | mysql-connector-java 29 | 30 | 31 | com.alibaba 32 | druid-spring-boot-starter 33 | 34 | 35 | com.baomidou 36 | mybatis-plus-boot-starter 37 | 38 | 39 | 40 | org.springframework.boot 41 | spring-boot-starter-web 42 | 43 | 44 | ai.djl 45 | api 46 | 47 | 48 | 49 | org.projectlombok 50 | lombok 51 | true 52 | 53 | 54 | 55 | org.apache.hadoop 56 | hadoop-common 57 | 58 | 59 | 60 | org.apache.hadoop 61 | hadoop-hdfs 62 | 63 | 64 | 65 | org.apache.hadoop 66 | hadoop-hdfs-client 67 | 68 | 69 | 70 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/Constants.java:
--------------------------------------------------------------------------------
package cn.edu.scut.bean;


/** Decimal (SI) multipliers; {@code Byte} is 8, i.e. bits per byte. */
public enum Constants {

    Byte(8), Kilo(1000), Mega(1000 * 1000), Giga(1000 * 1000 * 1000);

    // final: enum constants are shared singletons, so their multiplier must not be reassignable.
    public final Integer value;

    Constants(Integer value) {
        this.value = value;
    }
}
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/EdgeNode.java:
--------------------------------------------------------------------------------
package cn.edu.scut.bean;

import com.baomidou.mybatisplus.annotation.TableField;
import lombok.Data;

/** Edge node row; {@code capacity} is derived and not a database column. */
@Data
public class EdgeNode {
    private Integer id;
    private String name;
    private Long cpuNum;
    @TableField(exist = false)
    private Long capacity; // capacity = cpuNum * cpuCapacity
    private Double taskRate;
    private Double executionFailureRate;
}
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/Link.java:
--------------------------------------------------------------------------------
package cn.edu.scut.bean;

import lombok.Data;

/** Directed link between two nodes, identified by their names in source/destination. */
@Data
public class Link {
    private Integer id;
    private String source;
    private String destination;
    private Double transmissionRate;
    private Double transmissionFailureRate;
}
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/RatcVo.java:
--------------------------------------------------------------------------------
package cn.edu.scut.bean;

import lombok.Data;

/** Per-edge-node figures: execution failure rate, capacity, waiting time and total time. */
@Data
public class RatcVo {
    private String edgeId;
    private Double executionFailureRate;
    private Long capacity;
    private Long waitingTime;
    private Long totalTime;
}
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/StoreConstants.java:
--------------------------------------------------------------------------------
package cn.edu.scut.bean;

/** Binary (storage) multipliers; {@code Byte} is 8, i.e. bits per byte. */
public enum StoreConstants {

    Byte(8), Kilo(1024), Mega(1024 * 1024), Giga(1024 * 1024 * 1024);

    // final: enum constants are shared singletons, so their multiplier must not be reassignable.
    public final Integer value;

    StoreConstants(Integer value) {
        this.value = value;
    }
}
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/Task.java:
--------------------------------------------------------------------------------
package cn.edu.scut.bean;

import com.baomidou.mybatisplus.annotation.TableField;
import lombok.Data;

import java.time.LocalDateTime;

/** Offloaded task row; the fields marked exist = false are runtime-only and not persisted. */
@Data
public class Task {
    private Long id;
    private Integer timeSlot;
    private String source;
    private String destination;
    private TaskStatus status;
    // KB
    private Long taskSize;
    private Long taskComplexity;
    // cycle
    private Long cpuCycle; // cpuCycle = taskSize * taskComplexity
    // s
    private Long deadline;
    private Long transmissionTime;
    private Long executionTime;
    private Long transmissionWaitingTime;
    private Long executionWaitingTime;

    @TableField(exist = false)
    private LocalDateTime arrivalTime;
    @TableField(exist = false)
    private LocalDateTime beginTransmissionTime;
    @TableField(exist = false)
    private LocalDateTime endTransmissionTime;
    @TableField(exist = false)
    private LocalDateTime beginExecutionTime;
    @TableField(exist = false)
    private LocalDateTime endExecutionTime;
    @TableField(exist = false)
    private Double transmissionFailureRate;
    @TableField(exist = false)
    private Double executionFailureRate;

    // Comma-separated availability mask (parsed by TransitionService.getAvailAction).
    private String availAction;
}
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/TaskStatus.java:
--------------------------------------------------------------------------------
package cn.edu.scut.bean;

/** Lifecycle status of a task. */
public enum TaskStatus {
    EMPTY, NEW, SUCCESS, DROP, TRANSMISSION_FAILURE, EXECUTION_FAILURE
}
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/mapper/EdgeNodeMapper.java:
--------------------------------------------------------------------------------
package cn.edu.scut.mapper;

import cn.edu.scut.bean.EdgeNode;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;

/** MyBatis-Plus CRUD mapper; the type argument tells MyBatis-Plus which entity/table it maps. */
@Mapper
public interface EdgeNodeMapper extends BaseMapper<EdgeNode> {
}
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/mapper/LinkMapper.java:
--------------------------------------------------------------------------------
package cn.edu.scut.mapper;

import cn.edu.scut.bean.Link;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;

/** MyBatis-Plus CRUD mapper for {@link Link}. */
@Mapper
public interface LinkMapper extends BaseMapper<Link> {
}
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/mapper/TaskMapper.java:
--------------------------------------------------------------------------------
package cn.edu.scut.mapper;

import cn.edu.scut.bean.Task;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;

/** MyBatis-Plus CRUD mapper for {@link Task}. */
@Mapper
public interface TaskMapper extends BaseMapper<Task> {
}
-------------------------------------------------------------------------------- /edge-api/src/main/java/cn/edu/scut/service/EdgeNodeService.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.service; 2 | 3 | import cn.edu.scut.bean.EdgeNode; 4 | import cn.edu.scut.mapper.EdgeNodeMapper; 5 | import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; 6 | import org.springframework.stereotype.Service; 7 | 8 | @Service 9 | public class EdgeNodeService extends ServiceImpl { 10 | } 11 | -------------------------------------------------------------------------------- /edge-api/src/main/java/cn/edu/scut/service/LinkService.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.service; 2 | 3 | import cn.edu.scut.bean.Link; 4 | import cn.edu.scut.mapper.LinkMapper; 5 | import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; 6 | import org.springframework.stereotype.Service; 7 | 8 | @Service 9 | public class LinkService extends ServiceImpl { 10 | } 11 | -------------------------------------------------------------------------------- /edge-api/src/main/java/cn/edu/scut/service/TaskService.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.service; 2 | 3 | import cn.edu.scut.bean.Task; 4 | import cn.edu.scut.mapper.TaskMapper; 5 | import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; 6 | import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; 7 | import org.springframework.beans.factory.annotation.Autowired; 8 | import org.springframework.stereotype.Service; 9 | 10 | import java.util.Map; 11 | 12 | @Service 13 | public class TaskService extends ServiceImpl { 14 | 15 | @Autowired 16 | private TaskMapper taskMapper; 17 | 18 | public float getSuccessRate(Map map) { 19 | Long successTasks = taskMapper.selectCount(new QueryWrapper().eq("status", 
"SUCCESS").gt("id", map.get("startIndex"))); 20 | Long totalTasks = taskMapper.selectCount(new QueryWrapper().gt("id", map.get("startIndex"))); 21 | return Float.valueOf(successTasks) / Float.valueOf(totalTasks); 22 | } 23 | 24 | public double getSuccessRate() { 25 | Long successTasks = taskMapper.selectCount(new QueryWrapper().eq("status", "SUCCESS")); 26 | Long totalTasks = taskMapper.selectCount(new QueryWrapper()) - taskMapper.selectCount(new QueryWrapper().eq("status", "EMPTY")); 27 | return Double.valueOf(successTasks) / Double.valueOf(totalTasks); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /edge-api/src/main/java/cn/edu/scut/util/ArrayUtils.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.util; 2 | 3 | import java.util.Arrays; 4 | import java.util.List; 5 | import java.util.stream.Collectors; 6 | 7 | public class ArrayUtils { 8 | public static float[] toFloatArray(List list) { 9 | var res = new float[list.size()]; 10 | for (int i = 0; i < list.size(); i++) { 11 | res[i] = list.get(i); 12 | } 13 | return res; 14 | } 15 | 16 | public static double[] toDoubleArray(List list) { 17 | var res = new double[list.size()]; 18 | for (int i = 0; i < list.size(); i++) { 19 | res[i] = list.get(i); 20 | } 21 | return res; 22 | } 23 | 24 | public static String arrayToString(int[] array) { 25 | return Arrays.stream(array).mapToObj(Integer::toString).collect(Collectors.joining(",")); 26 | } 27 | 28 | public static double[] floatToDoubleArray(float[] x) { 29 | double[] ret = new double[x.length]; 30 | for (int i = 0; i < x.length; i++) { 31 | ret[i] = x[i]; 32 | } 33 | return ret; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /edge-controller/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM lhc-ubuntu-20:v1.1 2 | LABEL maintainer=hongcai 3 | ENV PATH 
/root/lhc_dev/jdk-17.0.5/bin:${PATH} 4 | COPY target/edge-controller-1.0-SNAPSHOT.jar /root/app.jar 5 | WORKDIR /root 6 | ENTRYPOINT ["sh","-c","java -jar app.jar"] -------------------------------------------------------------------------------- /edge-controller/README.md: -------------------------------------------------------------------------------- 1 | mock the users to seed task to edge nodes. 2 | generate the configuration of task, edge node and link. 3 | start, stop the experiment. 4 | -------------------------------------------------------------------------------- /edge-controller/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | edge-computing 7 | cn.edu.scut 8 | 1.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | edge-controller 13 | 14 | 15 | 17 16 | 17 17 | 18 | 19 | 20 | 21 | 22 | mysql 23 | mysql-connector-java 24 | 25 | 26 | com.alibaba 27 | druid-spring-boot-starter 28 | 29 | 30 | com.baomidou 31 | mybatis-plus-boot-starter 32 | 33 | 34 | 35 | 36 | 37 | com.alibaba.cloud 38 | spring-cloud-starter-alibaba-nacos-config 39 | 40 | 41 | org.springframework.cloud 42 | spring-cloud-starter-bootstrap 43 | 44 | 45 | 46 | com.alibaba.cloud 47 | spring-cloud-starter-alibaba-nacos-discovery 48 | 49 | 50 | 51 | 52 | org.springframework.boot 53 | spring-boot-starter-web 54 | 55 | 56 | 57 | org.springframework.boot 58 | spring-boot-starter-logging 59 | 60 | 61 | 62 | org.projectlombok 63 | lombok 64 | true 65 | 66 | 67 | 68 | 69 | org.springframework.boot 70 | spring-boot-starter-test 71 | test 72 | 73 | 74 | 75 | 76 | org.springframework.cloud 77 | spring-cloud-loadbalancer 78 | 79 | 80 | 81 | cn.edu.scut 82 | edge-api 83 | 1.0-SNAPSHOT 84 | 85 | 86 | 87 | org.apache.commons 88 | commons-math3 89 | 90 | 91 | cn.edu.scut 92 | edge-algorithm 93 | 1.0-SNAPSHOT 94 | test 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | org.springframework.boot 103 | spring-boot-maven-plugin 104 | 105 | true 106 | true 107 | 108 | 109 | 110 | 111 | 
repackage 112 | 113 | 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /edge-controller/src/main/java/cn/edu/scut/ControllerApp.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut; 2 | 3 | import cn.edu.scut.utils.SpringBeanUtils; 4 | import org.springframework.boot.SpringApplication; 5 | import org.springframework.boot.autoconfigure.SpringBootApplication; 6 | import org.springframework.cloud.client.discovery.EnableDiscoveryClient; 7 | import org.springframework.context.ConfigurableApplicationContext; 8 | import org.springframework.scheduling.annotation.EnableAsync; 9 | 10 | @EnableDiscoveryClient 11 | @SpringBootApplication 12 | @EnableAsync 13 | public class ControllerApp { 14 | public static void main(String[] args) { 15 | ConfigurableApplicationContext context = SpringApplication.run(ControllerApp.class, args); 16 | SpringBeanUtils.setApplicationContext(context); 17 | } 18 | } -------------------------------------------------------------------------------- /edge-controller/src/main/java/cn/edu/scut/config/DataGenerationConfig.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.config; 2 | 3 | import org.apache.commons.math3.distribution.UniformIntegerDistribution; 4 | import org.apache.commons.math3.distribution.UniformRealDistribution; 5 | import org.apache.commons.math3.random.JDKRandomGenerator; 6 | import org.apache.commons.math3.random.RandomGenerator; 7 | import org.springframework.beans.factory.annotation.Value; 8 | import org.springframework.context.annotation.Bean; 9 | import org.springframework.context.annotation.Configuration; 10 | 11 | @Configuration 12 | public class DataGenerationConfig { 13 | 14 | @Value("${edgeComputing.minExecutionFailureRate}") 15 | private Double minExecutionFailureRate; 16 | 17 | @Value("${edgeComputing.maxExecutionFailureRate}") 18 | private 
Double maxExecutionFailureRate; 19 | 20 | @Value("${edgeComputing.maxCpuCore}") 21 | private Integer maxCpuCore; 22 | 23 | @Value("${edgeComputing.minCpuCore}") 24 | private Integer minCpuCore; 25 | 26 | @Value("${edgeComputing.minTaskRate}") 27 | private Double minTaskRate; 28 | 29 | @Value("${edgeComputing.maxTaskRate}") 30 | private Double maxTaskRate; 31 | 32 | @Value("${edgeComputing.minTransmissionRate}") 33 | private Double minTransmissionRate; 34 | 35 | @Value("${edgeComputing.maxTransmissionRate}") 36 | private Double maxTransmissionRate; 37 | 38 | @Value("${edgeComputing.minTransmissionFailureRate}") 39 | private Double minTransmissionFailureRate; 40 | 41 | @Value("${edgeComputing.maxTransmissionFailureRate}") 42 | private Double maxTransmissionFailureRate; 43 | 44 | @Value("${edgeComputing.edgeNodeSeed}") 45 | private int edgeNodeSeed; 46 | 47 | @Value("${edgeComputing.taskSeed}") 48 | private int taskSeed; 49 | 50 | @Value("${edgeComputing.minTaskSize}") 51 | private int minTaskSize; 52 | 53 | @Value("${edgeComputing.maxTaskSize}") 54 | private int maxTaskSize; 55 | 56 | @Value("${edgeComputing.minTaskComplexity}") 57 | private int minTaskComplexity; 58 | 59 | @Value("${edgeComputing.maxTaskComplexity}") 60 | private int maxTaskComplexity; 61 | 62 | @Bean 63 | public RandomGenerator randomGenerator() { 64 | var random = new JDKRandomGenerator(); 65 | random.setSeed(edgeNodeSeed); 66 | return random; 67 | } 68 | 69 | @Bean 70 | public UniformRealDistribution executionFailureRandom(RandomGenerator randomGenerator) { 71 | return new UniformRealDistribution(randomGenerator, minExecutionFailureRate, maxExecutionFailureRate); 72 | } 73 | 74 | @Bean 75 | public UniformIntegerDistribution cpuCoreRandom(RandomGenerator randomGenerator) { 76 | return new UniformIntegerDistribution(randomGenerator, minCpuCore / 4, maxCpuCore / 4); 77 | } 78 | 79 | @Bean 80 | public UniformRealDistribution taskRateRandom(RandomGenerator randomGenerator) { 81 | return new 
UniformRealDistribution(randomGenerator, minTaskRate, maxTaskRate); 82 | } 83 | 84 | @Bean 85 | public UniformRealDistribution transmissionRateRandom(RandomGenerator randomGenerator) { 86 | return new UniformRealDistribution(randomGenerator, minTransmissionRate, maxTransmissionRate); 87 | } 88 | 89 | @Bean 90 | public UniformRealDistribution transmissionFailureRateRandom(RandomGenerator randomGenerator) { 91 | return new UniformRealDistribution(randomGenerator, minTransmissionFailureRate, maxTransmissionFailureRate); 92 | } 93 | 94 | @Bean 95 | public RandomGenerator taskRandomGenerator() { 96 | var random = new JDKRandomGenerator(); 97 | random.setSeed(taskSeed); 98 | return random; 99 | } 100 | 101 | @Bean 102 | public UniformIntegerDistribution taskSizeRandom(RandomGenerator taskRandomGenerator) { 103 | return new UniformIntegerDistribution(taskRandomGenerator, minTaskSize, maxTaskSize); 104 | } 105 | 106 | @Bean 107 | public UniformIntegerDistribution taskComplexityRandom(RandomGenerator taskRandomGenerator) { 108 | return new UniformIntegerDistribution(taskRandomGenerator, minTaskComplexity, maxTaskComplexity); 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /edge-controller/src/main/java/cn/edu/scut/config/SpringConfig.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.config; 2 | 3 | import org.springframework.cloud.client.loadbalancer.LoadBalanced; 4 | import org.springframework.context.annotation.Bean; 5 | import org.springframework.context.annotation.Configuration; 6 | import org.springframework.web.client.RestTemplate; 7 | 8 | @Configuration 9 | public class SpringConfig { 10 | 11 | @Bean 12 | @LoadBalanced 13 | RestTemplate restTemplate() { 14 | return new RestTemplate(); 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- 
/edge-controller/src/main/java/cn/edu/scut/controller/EdgeNodeController.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.controller; 2 | 3 | import cn.edu.scut.service.EdgeConfigService; 4 | import cn.edu.scut.service.UserService; 5 | import lombok.extern.apachecommons.CommonsLog; 6 | import org.springframework.beans.factory.annotation.Autowired; 7 | import org.springframework.beans.factory.annotation.Value; 8 | import org.springframework.web.bind.annotation.GetMapping; 9 | import org.springframework.web.bind.annotation.RestController; 10 | import org.springframework.web.client.RestTemplate; 11 | 12 | @RestController 13 | @CommonsLog 14 | public class EdgeNodeController { 15 | @Autowired 16 | private RestTemplate restTemplate; 17 | 18 | @Autowired 19 | private EdgeConfigService edgeConfigService; 20 | 21 | @Autowired 22 | private UserService userService; 23 | 24 | @Value("${edgeComputing.edgeNodeNumber}") 25 | private int edgeNodeNumber; 26 | 27 | @GetMapping("/generate") 28 | public String generate() { 29 | log.info("generate all edge node configuration."); 30 | for (int i = 1; i <= edgeNodeNumber; i++) { 31 | String name = String.format("edge-node-%d", i); 32 | edgeConfigService.generateEdgeLink(name); 33 | } 34 | return "success\n"; 35 | } 36 | 37 | @GetMapping("/init") 38 | public String init() { 39 | log.info("init all edge node configuration."); 40 | for (int i = 1; i <= edgeNodeNumber; i++) { 41 | String name = String.format("edge-node-%d", i); 42 | log.info("init " + name); 43 | String url = String.format("http://%s/cmd/init", name); 44 | restTemplate.getForObject(url, String.class); 45 | } 46 | return "success\n"; 47 | } 48 | 49 | @GetMapping("/start") 50 | public String start() { 51 | log.info("start experiment."); 52 | userService.start(); 53 | return "success\n"; 54 | } 55 | 56 | @GetMapping("/stop") 57 | public String stop() { 58 | log.info("stop experiment."); 59 | userService.stop(); 60 | 
return "success\n"; 61 | } 62 | 63 | @GetMapping("/restart") 64 | public String restart() { 65 | log.info("restart experiment."); 66 | userService.restart(); 67 | return "success\n"; 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /edge-controller/src/main/java/cn/edu/scut/service/EdgeConfigService.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.service; 2 | 3 | import cn.edu.scut.bean.*; 4 | import cn.edu.scut.util.ArrayUtils; 5 | import lombok.extern.slf4j.Slf4j; 6 | import org.apache.commons.math3.distribution.UniformIntegerDistribution; 7 | import org.apache.commons.math3.distribution.UniformRealDistribution; 8 | import org.springframework.beans.factory.annotation.Autowired; 9 | import org.springframework.beans.factory.annotation.Value; 10 | import org.springframework.cloud.context.config.annotation.RefreshScope; 11 | import org.springframework.stereotype.Service; 12 | 13 | @Service 14 | @Slf4j 15 | @RefreshScope 16 | public class EdgeConfigService { 17 | 18 | @Autowired 19 | private UniformRealDistribution executionFailureRandom; 20 | 21 | @Autowired 22 | private UniformIntegerDistribution cpuCoreRandom; 23 | 24 | @Autowired 25 | private UniformRealDistribution taskRateRandom; 26 | 27 | @Autowired 28 | private UniformRealDistribution transmissionRateRandom; 29 | 30 | @Autowired 31 | private UniformRealDistribution transmissionFailureRateRandom; 32 | 33 | @Autowired 34 | private UniformIntegerDistribution taskSizeRandom; 35 | 36 | @Autowired 37 | private UniformIntegerDistribution taskComplexityRandom; 38 | 39 | @Autowired 40 | private EdgeNodeService edgeNodeService; 41 | 42 | @Autowired 43 | private LinkService linkService; 44 | 45 | @Value("${edgeComputing.minTransmissionFailureRate}") 46 | Double minTransmissionFailureRate; 47 | 48 | @Value("${edgeComputing.maxTransmissionRate}") 49 | Double maxTransmissionRate; 50 | 51 | 
@Value("${edgeComputing.minTaskSize}") 52 | Long minTaskSize; 53 | 54 | @Value("${edgeComputing.maxTaskSize}") 55 | Long maxTaskSize; 56 | 57 | @Value("${edgeComputing.minTaskComplexity}") 58 | Long minTaskComplexity; 59 | 60 | @Value("${edgeComputing.maxTaskComplexity}") 61 | Long maxTaskComplexity; 62 | 63 | @Value("${edgeComputing.deadline}") 64 | Long deadline; 65 | 66 | @Value("${edgeComputing.edgeNodeNumber}") 67 | private int edgeNodeNumber; 68 | 69 | public void generateEdgeLink(String name) { 70 | EdgeNode edgeNode = new EdgeNode(); 71 | edgeNode.setName(name); 72 | edgeNode.setExecutionFailureRate(executionFailureRandom.sample()); 73 | // ! int -> long 74 | Integer core = cpuCoreRandom.sample() * 4; 75 | edgeNode.setCpuNum(core.longValue()); 76 | edgeNode.setTaskRate(taskRateRandom.sample()); 77 | edgeNodeService.save(edgeNode); 78 | 79 | for (int i = 1; i <= edgeNodeNumber; i++) { 80 | String service = String.format("edge-node-%d", i); 81 | Link link = new Link(); 82 | if (service.equals(name)) { 83 | link.setSource(name); 84 | link.setDestination(name); 85 | link.setTransmissionRate(maxTransmissionRate * Constants.Mega.value * Constants.Byte.value); 86 | link.setTransmissionFailureRate(minTransmissionFailureRate); 87 | } else { 88 | link.setSource(name); 89 | link.setDestination(service); 90 | link.setTransmissionRate(transmissionRateRandom.sample() * Constants.Mega.value * Constants.Byte.value); 91 | link.setTransmissionFailureRate(transmissionFailureRateRandom.sample()); 92 | } 93 | linkService.save(link); 94 | } 95 | } 96 | 97 | public Task generateTask(String name) { 98 | Task task = new Task(); 99 | task.setTaskSize(((long) taskSizeRandom.sample()) * StoreConstants.Kilo.value * StoreConstants.Byte.value); 100 | task.setTaskComplexity((long) taskComplexityRandom.sample()); 101 | task.setCpuCycle(task.getTaskComplexity() * task.getTaskSize()); 102 | task.setDeadline(deadline); 103 | task.setSource(name); 104 | task.setStatus(TaskStatus.NEW); 105 | 
var availAction = new int[edgeNodeNumber + 1]; 106 | for (int i = 0; i < edgeNodeNumber; i++) { 107 | availAction[i] = 1; 108 | } 109 | task.setAvailAction(ArrayUtils.arrayToString(availAction)); 110 | return task; 111 | } 112 | 113 | public Task generateEmptyTask(String name) { 114 | Task task = new Task(); 115 | task.setTaskSize(0L); 116 | task.setTaskComplexity(0L); 117 | task.setCpuCycle(0L); 118 | task.setDeadline(0L); 119 | task.setSource(name); 120 | task.setDestination("null"); 121 | task.setStatus(TaskStatus.EMPTY); 122 | var availAction = new int[edgeNodeNumber + 1]; 123 | availAction[edgeNodeNumber] = 1; 124 | task.setAvailAction(ArrayUtils.arrayToString(availAction)); 125 | task.setExecutionTime(0L); 126 | task.setExecutionWaitingTime(0L); 127 | task.setTransmissionTime(0L); 128 | task.setTransmissionWaitingTime(0L); 129 | return task; 130 | } 131 | } -------------------------------------------------------------------------------- /edge-controller/src/main/java/cn/edu/scut/service/UserService.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.service; 2 | 3 | import cn.edu.scut.thread.UserRunnable; 4 | import cn.edu.scut.utils.SpringBeanUtils; 5 | import org.springframework.beans.factory.annotation.Value; 6 | import org.springframework.stereotype.Service; 7 | 8 | import java.util.concurrent.ScheduledThreadPoolExecutor; 9 | import java.util.concurrent.TimeUnit; 10 | 11 | @Service 12 | public class UserService { 13 | 14 | @Value("${edgeComputing.timeSlot}") 15 | private int timeSlot; 16 | 17 | private ScheduledThreadPoolExecutor threadPoolExecutor = new ScheduledThreadPoolExecutor(1); 18 | 19 | public void start() { 20 | UserRunnable userRunnable = SpringBeanUtils.applicationContext.getBean(UserRunnable.class); 21 | threadPoolExecutor.scheduleAtFixedRate(userRunnable, 0, timeSlot, TimeUnit.MILLISECONDS); 22 | } 23 | 24 | public void stop() { 25 | threadPoolExecutor.shutdown(); 26 | } 27 | 28 | 
public void restart() { 29 | threadPoolExecutor = new ScheduledThreadPoolExecutor(1); 30 | UserRunnable userRunnable = SpringBeanUtils.applicationContext.getBean(UserRunnable.class); 31 | userRunnable.setTimeSlot(0); 32 | threadPoolExecutor.scheduleAtFixedRate(userRunnable, 0, timeSlot, TimeUnit.MILLISECONDS); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /edge-controller/src/main/java/cn/edu/scut/thread/UserRunnable.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.thread; 2 | 3 | import cn.edu.scut.bean.EdgeNode; 4 | import cn.edu.scut.service.EdgeConfigService; 5 | import cn.edu.scut.service.EdgeNodeService; 6 | import cn.edu.scut.service.TaskService; 7 | import lombok.Setter; 8 | import lombok.extern.slf4j.Slf4j; 9 | import org.apache.commons.math3.random.RandomGenerator; 10 | import org.springframework.beans.factory.annotation.Autowired; 11 | import org.springframework.beans.factory.annotation.Value; 12 | import org.springframework.context.annotation.Scope; 13 | import org.springframework.stereotype.Component; 14 | import org.springframework.web.client.RestTemplate; 15 | 16 | import java.util.List; 17 | 18 | @Setter 19 | @Slf4j 20 | @Scope("prototype") 21 | @Component 22 | public class UserRunnable implements Runnable { 23 | 24 | @Autowired 25 | private RestTemplate restTemplate; 26 | 27 | @Autowired 28 | private RandomGenerator taskRandomGenerator; 29 | 30 | @Autowired 31 | private EdgeConfigService edgeConfigService; 32 | 33 | @Autowired 34 | private EdgeNodeService edgeNodeService; 35 | 36 | @Autowired 37 | private TaskService taskService; 38 | 39 | private int timeSlot = 0; 40 | 41 | @Value("${edgeComputing.episodeLimit}") 42 | private int episodeLimit; 43 | 44 | @Override 45 | public void run() { 46 | timeSlot += 1; 47 | if (timeSlot > episodeLimit + 1) { 48 | return; 49 | } 50 | try { 51 | List edgeNodes = edgeNodeService.list(); 52 | for 
(EdgeNode edgeNode : edgeNodes) { 53 | boolean flag = taskRandomGenerator.nextDouble() < edgeNode.getTaskRate(); 54 | if (flag) { 55 | var task = edgeConfigService.generateTask(edgeNode.getName()); 56 | task.setTimeSlot(timeSlot); 57 | taskService.save(task); 58 | String url = String.format("http://%s/user/task", edgeNode.getName()); 59 | restTemplate.postForObject(url, task, String.class); 60 | } else { 61 | var task = edgeConfigService.generateEmptyTask(edgeNode.getName()); 62 | task.setTimeSlot(timeSlot); 63 | taskService.save(task); 64 | } 65 | } 66 | } catch (Exception e) { 67 | e.printStackTrace(); 68 | } 69 | } 70 | } -------------------------------------------------------------------------------- /edge-controller/src/main/java/cn/edu/scut/utils/SpringBeanUtils.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.utils; 2 | 3 | import org.springframework.context.ApplicationContext; 4 | 5 | public class SpringBeanUtils { 6 | public static ApplicationContext applicationContext; 7 | 8 | public static void setApplicationContext(ApplicationContext applicationContext) { 9 | SpringBeanUtils.applicationContext = applicationContext; 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /edge-controller/src/main/resources/META-INF/MANIFEST.MF: -------------------------------------------------------------------------------- 1 | Manifest-Version: 1.0 2 | Main-Class: cn.edu.scut.ControllerApp 3 | 4 | -------------------------------------------------------------------------------- /edge-controller/src/main/resources/bootstrap.yaml: -------------------------------------------------------------------------------- 1 | spring: 2 | application: 3 | name: edge-controller 4 | profiles: 5 | # active: edge-computing,mappo 6 | # active: edge-computing,masac 7 | # active: edge-computing,reliability-two-choice 8 | # active: edge-computing,reactive 9 | active: edge-computing,random 
10 | # active: edge-computing,local 11 | cloud: 12 | nacos: 13 | discovery: 14 | server-addr: 222.201.187.50:30848 15 | watch-delay: 3000 16 | config: 17 | server-addr: 222.201.187.50:30848 18 | prefix: application 19 | file-extension: yaml -------------------------------------------------------------------------------- /edge-experiment/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM lhc-ubuntu-20:v1.1 2 | LABEL maintainer=hongcai 3 | ENV PATH /root/lhc_dev/jdk-17.0.5/bin:${PATH} 4 | COPY target/edge-experiment-1.0-SNAPSHOT.jar /root/app.jar 5 | WORKDIR /root 6 | ENTRYPOINT ["sh","-c","java -jar app.jar"] -------------------------------------------------------------------------------- /edge-experiment/README.md: -------------------------------------------------------------------------------- 1 | run the experiment. 2 | 3 | - online heuristic. 4 | - online DRL. 5 | - offline DRL. 6 | - offline to online DRL. -------------------------------------------------------------------------------- /edge-experiment/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | cn.edu.scut 8 | edge-computing 9 | 1.0-SNAPSHOT 10 | 11 | 12 | edge-experiment 13 | 14 | 15 | 17 16 | 17 17 | UTF-8 18 | 19 | 20 | 21 | 22 | 23 | mysql 24 | mysql-connector-java 25 | 26 | 27 | com.alibaba 28 | druid-spring-boot-starter 29 | 30 | 31 | com.baomidou 32 | mybatis-plus-boot-starter 33 | 34 | 35 | 36 | cn.edu.scut 37 | edge-algorithm 38 | 1.0-SNAPSHOT 39 | 40 | 41 | 42 | com.alibaba.cloud 43 | spring-cloud-starter-alibaba-nacos-discovery 44 | 45 | 46 | com.alibaba.cloud 47 | spring-cloud-starter-alibaba-nacos-config 48 | 49 | 50 | org.springframework.cloud 51 | spring-cloud-starter-bootstrap 52 | 53 | 54 | 55 | 56 | org.springframework.boot 57 | spring-boot-starter-web 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | org.springframework.boot 75 | 
spring-boot-starter-logging 76 | 77 | 78 | 79 | 80 | org.projectlombok 81 | lombok 82 | true 83 | 84 | 85 | 86 | 87 | org.springframework.boot 88 | spring-boot-starter-test 89 | test 90 | 91 | 92 | 93 | 94 | org.springframework.cloud 95 | spring-cloud-loadbalancer 96 | 97 | 98 | 99 | 100 | cn.edu.scut 101 | edge-api 102 | 1.0-SNAPSHOT 103 | 104 | 105 | 106 | 107 | org.apache.commons 108 | commons-math3 109 | 110 | 111 | 112 | 113 | org.apache.hadoop 114 | hadoop-common 115 | 116 | 117 | org.slf4j 118 | slf4j-api 119 | 120 | 121 | org.slf4j 122 | slf4j-reload4j 123 | 124 | 125 | 126 | 127 | org.apache.hadoop 128 | hadoop-hdfs 129 | 130 | 131 | org.apache.hadoop 132 | hadoop-hdfs-client 133 | 134 | 135 | 136 | 137 | tech.tablesaw 138 | tablesaw-core 139 | LATEST 140 | 141 | 142 | tech.tablesaw 143 | tablesaw-jsplot 144 | 0.43.1 145 | 146 | 147 | 148 | 149 | 150 | 151 | org.springframework.boot 152 | spring-boot-maven-plugin 153 | 154 | true 155 | true 156 | 157 | 158 | 159 | 160 | repackage 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /edge-experiment/src/main/java/cn/edu/scut/ExperimentApp.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut; 2 | 3 | import cn.edu.scut.runner.*; 4 | import cn.edu.scut.util.SpringBeanUtils; 5 | import org.springframework.boot.SpringApplication; 6 | import org.springframework.boot.autoconfigure.SpringBootApplication; 7 | 8 | @SpringBootApplication 9 | public class ExperimentApp { 10 | public static void main(String[] args) { 11 | var context = SpringApplication.run(ExperimentApp.class, args); 12 | SpringBeanUtils.setApplicationContext(context); 13 | var env = context.getEnvironment(); 14 | var runnerType = env.getProperty("edgeComputing.runner", String.class); 15 | assert runnerType != null; 16 | var runner = switch (runnerType) { 17 | case "rl-online" -> 
context.getBean(OnlineMARLTrainingRunner.class); 18 | case "rl-offline" -> context.getBean(OfflineMARLTrainingRunner.class); 19 | case "rl-test" -> context.getBean(OnlineMARLTestingRunner.class); 20 | case "heuristic" -> context.getBean(OnlineHeuristicTestRunner.class); 21 | case "heuristic-data" -> context.getBean(OnlineHeuristicDataRunner.class); 22 | default -> throw new RuntimeException("error in runner type."); 23 | }; 24 | // waiting for edge-nodes and edge-controller start. 25 | try { 26 | Thread.sleep(5000); 27 | } catch (InterruptedException e) { 28 | throw new RuntimeException(e); 29 | } 30 | runner.run(); 31 | } 32 | } -------------------------------------------------------------------------------- /edge-experiment/src/main/java/cn/edu/scut/config/HadoopConfig.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.config; 2 | 3 | import org.apache.hadoop.fs.FileSystem; 4 | import org.springframework.beans.factory.annotation.Value; 5 | import org.springframework.context.annotation.Bean; 6 | import org.springframework.context.annotation.Configuration; 7 | 8 | import java.io.IOException; 9 | import java.net.URI; 10 | import java.net.URISyntaxException; 11 | 12 | @Configuration 13 | public class HadoopConfig { 14 | 15 | @Value("${hadoop.hdfs.url}") 16 | private String hdfsUrl; 17 | 18 | @Bean 19 | public FileSystem fileSystem() { 20 | URI uri; 21 | try { 22 | uri = new URI(hdfsUrl); 23 | } catch (URISyntaxException e) { 24 | throw new RuntimeException(e); 25 | } 26 | var configuration = new org.apache.hadoop.conf.Configuration(); 27 | String user = "hongcai"; 28 | FileSystem fileSystem; 29 | try { 30 | fileSystem = FileSystem.get(uri, configuration, user); 31 | } catch (IOException | InterruptedException e) { 32 | throw new RuntimeException(e); 33 | } 34 | return fileSystem; 35 | } 36 | } 37 | -------------------------------------------------------------------------------- 
/edge-experiment/src/main/java/cn/edu/scut/config/SpringConfig.java:
--------------------------------------------------------------------------------
package cn.edu.scut.config;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.client.loadbalancer.LoadBalanced;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;

import java.util.Random;

/**
 * General-purpose beans of the experiment service.
 */
@Configuration
public class SpringConfig {
    // seed of the schedulerRandom bean
    @Value("${edgeComputing.seed}")
    Integer seed;

    /** RestTemplate whose host names (e.g. http://edge-controller) are resolved through the load balancer. */
    @Bean
    @LoadBalanced
    RestTemplate restTemplate() {
        return new RestTemplate();
    }

    /** Random source seeded with edgeComputing.seed. */
    @Bean
    public Random schedulerRandom() {
        return new Random(seed);
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/controller/ModelController.java:
--------------------------------------------------------------------------------
package cn.edu.scut.controller;

import cn.edu.scut.agent.IMAAgent;
import cn.edu.scut.service.TransitionService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RestController;

/**
 * HTTP endpoint that asks the MARL agent for a scheduling action.
 */
@RestController
@Slf4j
public class ModelController {


    // Optional: presumably absent in heuristic profiles, which do not define an agent bean.
    @Autowired(required = false)
    private IMAAgent agent;

    @Autowired
    private TransitionService transitionService;

    /**
     * Selects an action for the given task from its state and available-action mask.
     *
     * @param taskId id of the task to schedule
     * @return the agent's action index plus one
     * @throws IllegalStateException if no agent bean is configured
     */
    @GetMapping("/action/{taskId}")
    public Integer action(@PathVariable("taskId") Long taskId) {
        // Fail with a clear message instead of a NullPointerException when the profile has no agent.
        if (agent == null) {
            throw new IllegalStateException("no MARL agent is configured, cannot select an action.");
        }
        var availAction = transitionService.getAvailAction(taskId);
        var state = transitionService.getState(taskId);
        // NOTE(review): the +1 presumably maps a 0-based action index onto 1-based edge-node ids — confirm.
        int i = agent.selectAction(state, availAction, false) + 1;
        return i;
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/IRunner.java:
--------------------------------------------------------------------------------
package cn.edu.scut.runner;

/**
 * One experiment procedure, selected and launched by ExperimentApp.
 */
public interface IRunner {
    /** Executes the experiment to completion. */
    void run();
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/OfflineMARLTrainingRunner.java:
--------------------------------------------------------------------------------
package cn.edu.scut.runner;

import cn.edu.scut.agent.IMAAgent;
import cn.edu.scut.agent.MABuffer;
import cn.edu.scut.service.*;
import cn.edu.scut.util.DateTimeUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;

import java.util.ArrayList;

/**
 * Offline MARL training.
 *
 * Loads a replay buffer from HDFS (rl.buffer-path) and trains the agent rl.training-time + 1 times. Every
 * rl.test-frequency iterations, the current model is published to the edge nodes and evaluated on the live cluster.
 */
@Lazy
@Service
@Slf4j
public class OfflineMARLTrainingRunner implements IRunner {

    @Value("${rl.name}")
    private String rlName;

    @Autowired
    private IMAAgent agent;

    @Autowired
    private TaskService taskService;

    @Autowired
    private LinkService linkService;

    @Autowired
    private EdgeNodeService edgeNodeService;

    @Autowired
    private PlotService plotService;

    @Autowired
    private RunnerService runnerService;

    // NOTE(review): presumably the same buffer bean the agent samples from in train() — confirm.
    @Autowired
    private MABuffer buffer;

    @Value("${rl.test-frequency}")
    private int testFrequency;

    @Value("${rl.training-time}")
    private int trainingTime;

    @Value("${rl.buffer-path}")
    private String bufferPath;

    @Value("${edgeComputing.flag}")
    private String flag;

    /**
     * Resets the cluster, loads the buffer, runs one preliminary episode whose tasks are deleted, then alternates
     * training with periodic evaluation and finally plots test success rate against training iteration.
     */
    public void run() {
        log.info("========================");
        log.info("run rl-offline runner.");
        log.info("========================");

        // remove data
        linkService.remove(null);
        edgeNodeService.remove(null);
        taskService.remove(null);
        // init edge node and link
        runnerService.init();
        buffer.loadHdfs(bufferPath);


        log.info("store communication information start!");
        runnerService.run();
        taskService.remove(null);
        log.info("store communication information end!");

        // run identifier: <timestamp>@<algorithm>@<edgeComputing.flag>
        var dateFlag = DateTimeUtils.getFlag();
        var flag = dateFlag + "@" + rlName + "@" + this.flag;

        var episodes = new ArrayList<Double>();
        var successRates = new ArrayList<Double>();
        for (int i = 0; i <= trainingTime; i++) {
            if (i % testFrequency == 0) {
                // update model: save to HDFS and ask the edge nodes to reload it
                agent.saveHdfsModel(flag);
                runnerService.updateModelByHdfs(flag);
                // performance
                double successRate = runnerService.test();
                log.info("train time: {}, test success rate: {}", i, successRate);
                episodes.add((double) i);
                successRates.add(successRate);
            }
            agent.train();
        }
        plotService.plot(episodes, successRates, flag);
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/OnlineHeuristicDataRunner.java:
--------------------------------------------------------------------------------
package cn.edu.scut.runner;

import cn.edu.scut.agent.MABuffer;
import cn.edu.scut.service.*;
import cn.edu.scut.util.DateTimeUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.stream.Collectors;

/**
 * Runs the configured heuristic scheduler for edgeComputing.episodeNumber episodes.
 *
 * Each episode is converted into replay-buffer transitions. The buffer is saved to HDFS under the run flag,
 * presumably as a dataset for OfflineMARLTrainingRunner.
 */
@Service
@Slf4j
@Lazy
public class OnlineHeuristicDataRunner implements IRunner {
    @Autowired
    private TaskService taskService;

    @Autowired
    private LinkService linkService;

    @Autowired
    private EdgeNodeService edgeNodeService;

    @Autowired
    private PlotService plotService;

    @Autowired
    private RunnerService runnerService;

    @Value("${edgeComputing.episodeNumber}")
    private int episodeNumber;

    @Value("${heuristic.name}")
    private String name;

    @Value("${edgeComputing.flag}")
    private String flag;

    @Autowired
    private MABuffer buffer;

    /**
     * Resets the cluster, runs one preliminary episode whose tasks are deleted, then collects episodeNumber
     * episodes. Afterwards it plots their success rates, saves the buffer and logs the average success rate.
     */
    public void run() {
        log.info("=============================");
        log.info("run heuristic-data runner");
        log.info("=============================");

        linkService.remove(null);
        edgeNodeService.remove(null);
        taskService.remove(null);
        runnerService.init();

        log.info("store communication information start!");
        runnerService.run();
        taskService.remove(null);
        log.info("store communication information end!");

        // run identifier: <timestamp>@<heuristic name>@<edgeComputing.flag>
        var dateFlag = DateTimeUtils.getFlag();
        var flag = dateFlag + "@" + name + "@" + this.flag;

        var episodes = new ArrayList<Double>();
        var successRates = new ArrayList<Double>();
        for (int currentEpisode = 1; currentEpisode <= episodeNumber; currentEpisode++) {
            runnerService.run();
            double successRate = taskService.getSuccessRate();
            log.info("running time: {}, success rate: {}", currentEpisode, successRate);
            episodes.add((double) currentEpisode);
            successRates.add(successRate);
            // turn the finished episode into transitions, then clear its tasks
            runnerService.addData();
            taskService.remove(null);
        }
        plotService.plot(episodes, successRates, flag);
        buffer.saveHdfs(flag);
        var averageSuccessRate = successRates.stream().collect(Collectors.averagingDouble(x -> x));
        log.info("average success rate: {}", averageSuccessRate);
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/OnlineHeuristicTestRunner.java:
--------------------------------------------------------------------------------
package cn.edu.scut.runner;

import cn.edu.scut.service.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;

/**
 * Evaluates the configured heuristic scheduler on a freshly initialised cluster.
 */
@Service
@Slf4j
@Lazy
public class OnlineHeuristicTestRunner implements IRunner {
    @Autowired
    private TaskService taskService;

    @Autowired
    private LinkService linkService;

    @Autowired
    private EdgeNodeService edgeNodeService;

    @Autowired
    private PlotService plotService;

    @Autowired
    private RunnerService runnerService;

    @Value("${edgeComputing.episodeNumber}")
    private int episodeNumber;

    @Value("${heuristic.name}")
    private String name;

    @Value("${edgeComputing.flag}")
    private String flag;

    @Value("${edgeComputing.taskSeed}")
    private int taskSeed;

    /**
     * Logs the task seed, resets the cluster and runs one preliminary episode whose tasks are deleted. It then
     * delegates the evaluation to RunnerService#test().
     */
    public void run() {
        log.info(" task seed : {}", taskSeed);
        announce();
        resetCluster();
        runPreliminaryEpisode();
        runnerService.test();
    }

    /** Prints the start banner of this runner. */
    private void announce() {
        final String rule = "=============================";
        log.info(rule);
        log.info("run heuristic runner.");
        log.info(rule);
    }

    /** Deletes links, edge nodes and tasks, then calls RunnerService#init(). */
    private void resetCluster() {
        linkService.remove(null);
        edgeNodeService.remove(null);
        taskService.remove(null);
        runnerService.init();
    }

    /** Runs one episode and deletes its tasks (logged as "store communication information"). */
    private void runPreliminaryEpisode() {
        log.info("store communication information start!");
        runnerService.run();
        taskService.remove(null);
        log.info("store communication information end!");
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/OnlineMARLTestingRunner.java:
--------------------------------------------------------------------------------
package cn.edu.scut.runner;

import cn.edu.scut.service.EdgeNodeService;
import cn.edu.scut.service.LinkService;
import cn.edu.scut.service.RunnerService;
import cn.edu.scut.service.TaskService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;


/**
 * Evaluates an already trained MARL model, identified by rl.model-flag, on a freshly initialised cluster.
 */
@Service
@Lazy
@Slf4j
public class OnlineMARLTestingRunner implements IRunner {

    @Autowired
    private TaskService taskService;

    @Autowired
    private LinkService linkService;

    @Autowired
    private EdgeNodeService edgeNodeService;

    @Autowired
    private RunnerService runnerService;

    @Value("${rl.model-flag}")
    private String modelFlag;

    /**
     * Resets the cluster and pushes the model to the edge nodes. It then runs one preliminary episode whose tasks
     * are deleted and delegates the evaluation to RunnerService#test().
     */
    @Override
    public void run() {
        log.info("========================");
        log.info("run rl-test runner!");
        log.info("========================");

        resetCluster();
        // make the edge nodes load the model stored on HDFS under rl.model-flag
        runnerService.updateModelByHdfs(modelFlag);

        log.info("store communication information start!");
        runnerService.run();
        taskService.remove(null);
        log.info("store communication information end!");

        runnerService.test();
    }

    /** Deletes links, edge nodes and tasks, then re-creates edge nodes and links via RunnerService#init(). */
    private void resetCluster() {
        linkService.remove(null);
        edgeNodeService.remove(null);
        taskService.remove(null);
        runnerService.init();
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/OnlineMARLTrainingRunner.java:
--------------------------------------------------------------------------------
1 | package
cn.edu.scut.runner;

import cn.edu.scut.agent.IMAAgent;
import cn.edu.scut.service.*;
import cn.edu.scut.util.DateTimeUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;

import java.util.ArrayList;

/**
 * Online MARL training.
 *
 * When rl.use-trained-model is true, the agent is first loaded from HDFS (rl.model-flag), so the same runner also
 * covers offline-to-online fine-tuning.
 */
@Component
@Slf4j
@Lazy
public class OnlineMARLTrainingRunner implements IRunner {

    @Value("${edgeComputing.episodeNumber}")
    private int episodeNumber;

    @Value("${rl.name}")
    private String rlName;

    @Autowired
    private IMAAgent agent;

    @Autowired
    private TaskService taskService;

    @Autowired
    private LinkService linkService;

    @Autowired
    private EdgeNodeService edgeNodeService;

    @Autowired
    private PlotService plotService;

    @Autowired
    private RunnerService runnerService;

    @Value("${rl.use-trained-model}")
    private boolean useTrainedModel;

    @Value("${rl.model-flag}")
    private String modelFlag;

    @Value("${edgeComputing.flag}")
    private String flag;

    @Value("${edgeComputing.testFrequency}")
    private int testFrequency;

    /**
     * Trains for episodes 0..episodeNumber. Before each episode the current model is published to the edge
     * nodes, and every testFrequency episodes it is evaluated. At the end the model is saved and the test
     * success-rate curve is plotted.
     */
    public void run() {
        log.info("========================");
        log.info("run rl-online runner!");
        log.info("========================");
        resetCluster();

        // offline to online: continue from a model trained earlier
        if (useTrainedModel) {
            agent.loadHdfsModel(modelFlag);
            log.info("load model from hdfs.");
        }
        // run identifier: <timestamp>@<algorithm>@<edgeComputing.flag>
        final String runFlag = DateTimeUtils.getFlag() + "@" + rlName + "@" + flag;

        log.info("store communication information start!");
        runnerService.run();
        taskService.remove(null);
        log.info("store communication information end!");

        ArrayList<Double> episodes = new ArrayList<>();
        ArrayList<Double> successRates = new ArrayList<>();
        for (int episode = 0; episode <= episodeNumber; episode++) {
            publishModel(runFlag);
            if (episode % testFrequency == 0) {
                log.info("start to test.");
                double testSuccessRate = runnerService.test();
                episodes.add((double) episode);
                successRates.add(testSuccessRate);
                log.info("end test.");
            }
            trainOnEpisode(episode);
        }
        agent.saveHdfsModel(runFlag);
        plotService.plot(episodes, successRates, runFlag);
    }

    /** Deletes links, edge nodes and tasks, then re-creates edge nodes and links via RunnerService#init(). */
    private void resetCluster() {
        linkService.remove(null);
        edgeNodeService.remove(null);
        taskService.remove(null);
        runnerService.init();
    }

    /** Saves the agent's current model to HDFS under runFlag and asks the edge nodes to reload it. */
    private void publishModel(String runFlag) {
        agent.saveHdfsModel(runFlag);
        runnerService.updateModelByHdfs(runFlag);
    }

    /** Runs one episode, stores its transitions, trains the agent once and deletes the episode's tasks. */
    private void trainOnEpisode(int episode) {
        runnerService.run();
        var successRate = taskService.getSuccessRate();
        log.info("training episode {}, success rate {}", episode, successRate);
        runnerService.addData();
        agent.train();
        taskService.remove(null);
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/service/PlotService.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.service; 2 | 3 | import cn.edu.scut.util.ArrayUtils; 4 | import cn.edu.scut.util.PlotUtils; 5 | import lombok.extern.slf4j.Slf4j; 6 | import org.apache.hadoop.fs.FileSystem; 7 | import org.apache.hadoop.fs.Path; 8 | import org.springframework.beans.factory.annotation.Autowired; 9 | import org.springframework.beans.factory.annotation.Value; 10 | import org.springframework.stereotype.Service; 11 | import tech.tablesaw.plotly.Plot; 12 | 13 | import java.io.IOException; 14 | import java.io.UncheckedIOException; 15 | import java.nio.file.Files; 16 | import java.nio.file.Paths; 17
| import java.util.ArrayList;

/**
 * Draws success-rate curves and archives the resulting HTML figures on HDFS.
 */
@Service
@Slf4j
public class PlotService {

    @Autowired
    private FileSystem fileSystem;

    @Value("${hadoop.hdfs.url}")
    private String hdfsUrl;

    /**
     * Plots success rate against episode. The figure is written to results/figure/{flag}.html and then moved
     * (the local copy is deleted) to the same relative path under hadoop.hdfs.url.
     *
     * @param episodes     x values
     * @param successRates y values, one per episode
     * @param flag         run identifier used as the file name
     * @throws java.io.UncheckedIOException if the local figure directory cannot be created
     * @throws RuntimeException             if the upload to HDFS fails
     */
    public void plot(ArrayList<Double> episodes, ArrayList<Double> successRates, String flag) {
        var episodes_ = ArrayUtils.toDoubleArray(episodes);
        var successRates_ = ArrayUtils.toDoubleArray(successRates);
        var figure = PlotUtils.plot(new double[][]{episodes_}, new double[][]{successRates_}, new String[]{"rl"}, "episode", "success rate");
        var path = Paths.get("results", "figure", flag + ".html");
        // Create the directory outside the browser try-block. Otherwise a failure here would be misreported as
        // "browser not support!" and only surface later as a confusing HDFS copy error.
        try {
            Files.createDirectories(path.getParent());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        try {
            Plot.show(figure, path.toFile());
        } catch (Exception e) {
            // NOTE(review): presumably thrown when no desktop browser is available (e.g. inside a container);
            // the HTML file may already have been written at that point.
            log.info("browser not support!");
            log.debug("Plot.show failed", e);
        }
        // delSrc = true, overwrite = true: move the local figure to HDFS
        try {
            fileSystem.copyFromLocalFile(true, true, new Path(path.toString()), new Path(hdfsUrl + "/" + path));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/service/RunnerService.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.service; 2 | 3 | 4 | import cn.edu.scut.agent.MABuffer; 5 | import cn.edu.scut.bean.Task; 6 | import cn.edu.scut.util.MathUtils; 7 | import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; 8 | import lombok.extern.slf4j.Slf4j; 9 | import org.springframework.beans.factory.annotation.Autowired; 10 | import org.springframework.beans.factory.annotation.Value; 11 | import org.springframework.context.annotation.Lazy; 12 | import org.springframework.core.io.FileSystemResource; 13 | import org.springframework.http.HttpEntity; 14 | import org.springframework.http.HttpHeaders; 15 | import
org.springframework.http.MediaType;
import org.springframework.http.client.MultipartBodyBuilder;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import java.util.ArrayList;
import java.util.Map;

/**
 * Drives experiments over HTTP (edge controller and edge nodes), evaluates success rates and converts finished
 * episodes into replay-buffer transitions.
 */
@Service
@Slf4j
public class RunnerService {

    // Service name of the edge controller, resolved by the load-balanced RestTemplate.
    private static final String CONTROLLER_URL = "http://edge-controller";

    @Autowired
    private RestTemplate restTemplate;

    @Autowired
    private TaskService taskService;

    @Value("${edgeComputing.edgeNodeNumber}")
    private int agentNumber;

    @Value("${edgeComputing.episodeLimit}")
    private int episodeLimit;

    @Value("${edgeComputing.timeSlot}")
    private int timeSlotLen;

    // heuristic-test: defaults to 0 so profiles without rl.state-shape still start
    @Value("${rl.state-shape:0}")
    private int stateShape;

    @Value("${edgeComputing.testNumber}")
    private int testNumber;

    // heuristic-test: optional and lazily resolved, so profiles without a buffer bean can still use this service
    @Lazy
    @Autowired(required = false)
    private MABuffer buffer;

    @Autowired
    private TransitionService transitionService;

    /** Service name of the index-th edge node; edge nodes are numbered from 1. */
    private static String edgeNodeName(int index) {
        return String.format("edge-node-%d", index);
    }

    /** Asks the edge controller to generate an edge-node configuration and then to initialise the edge nodes. */
    public void init() {
        var generateMessage = restTemplate.getForObject(CONTROLLER_URL + "/generate", String.class);
        log.info("generate edge nodes configuration: {}", generateMessage);
        var initMessage = restTemplate.getForObject(CONTROLLER_URL + "/init", String.class);
        log.info("init edge nodes: {}", initMessage);
    }

    /**
     * Runs one episode.
     *
     * Restarts the experiment on the controller and polls the task table once per second until at least
     * (episodeLimit + 1) * agentNumber tasks have left status NEW. Then it stops the experiment.
     */
    public void run() {
        var restartMessage = restTemplate.getForObject(CONTROLLER_URL + "/restart", String.class);
        log.info("restart experiment : {}", restartMessage);
        while (true) {
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                // NOTE(review): the interrupt is only logged; polling continues.
                log.error("{}", e.getMessage());
            }
            var count = taskService.count(new QueryWrapper<Task>().ne("status", "NEW"));
            if (count >= (episodeLimit + 1) * agentNumber) {
                break;
            }
        }
        var stopMessage = restTemplate.getForObject(CONTROLLER_URL + "/stop", String.class);
        log.info("stop experiment: {}", stopMessage);
    }

    /**
     * Runs testNumber evaluation episodes, deleting the tasks after each one.
     *
     * @return the mean success rate; the standard deviation is only logged
     */
    public double test() {
        var list = new ArrayList<Double>();
        for (int i = 1; i <= testNumber; i++) {
            run();
            var successRate = taskService.getSuccessRate();
            log.info("test {}, success rate : {}", i, successRate);
            list.add(successRate);
            taskService.remove(null);
        }
        var avg = MathUtils.avg(list);
        var std = MathUtils.std(list);
        log.info("avg success rate: {}", avg);
        log.info("std success rate: {}", std);
        return avg;
    }

    /**
     * Calls /updateParam/{flag} on every edge node, edge-node-1 .. edge-node-{agentNumber}. This presumably makes
     * each node reload the model stored on HDFS under flag.
     */
    public void updateModelByHdfs(String flag) {
        // Inclusive bound: nodes are numbered 1..agentNumber (see addData), so the last node must be updated too.
        for (int i = 1; i <= agentNumber; i++) {
            String edgeNodeId = edgeNodeName(i);
            restTemplate.getForObject("http://" + edgeNodeId + "/updateParam/{flag}", String.class, Map.of("flag", flag));
        }
    }

    /**
     * Uploads results/model/{flag}/actor.param as multipart part "actor" to /updateModel on every edge node,
     * edge-node-1 .. edge-node-{agentNumber}.
     *
     * @throws RuntimeException wrapping the first failed upload
     */
    public void updateModelByFileTransfer(String flag) {
        // Inclusive bound: nodes are numbered 1..agentNumber, so the last node must be updated too.
        for (int i = 1; i <= agentNumber; i++) {
            String edgeNodeId = edgeNodeName(i);
            var headers = new HttpHeaders();
            headers.setContentType(MediaType.MULTIPART_FORM_DATA);
            var multipartBodyBuilder = new MultipartBodyBuilder();
            var file1 = new FileSystemResource("results/model/" + flag + "/actor.param");
            multipartBodyBuilder.part("actor", file1, MediaType.TEXT_PLAIN);
            var multipartBody = multipartBodyBuilder.build();
            var httpEntity = new HttpEntity<>(multipartBody, headers);
            try {
                restTemplate.postForObject("http://" + edgeNodeId + "/updateModel", httpEntity, String.class);
            } catch (Exception e) {
                log.error(e.getMessage());
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Converts the finished episode into transitions and inserts them into the buffer.
     *
     * For every slot i in 1..episodeLimit and node j in 1..agentNumber, the tasks of edge-node-j at slots i and
     * i + 1 give (state, action, available actions, next state). Each task's reward is credited to the slot in
     * which it finishes: its start slot plus total latency / slot length. Every agent then receives the sum
     * credited to slot i as its reward, i.e. a shared team reward.
     */
    public void addData() {
        // Only slots 1..episodeLimit are read back below; rewards credited to later slots are dropped.
        var teamRewards = new float[episodeLimit];
        var states = new float[episodeLimit][agentNumber][stateShape];
        var actions = new int[episodeLimit][agentNumber][1];
        var availActions = new int[episodeLimit][agentNumber][1];
        var rewards = new float[episodeLimit][agentNumber][1];
        var nextStates = new float[episodeLimit][agentNumber][stateShape];

        for (int i = 1; i <= episodeLimit; i++) {
            for (int j = 1; j <= agentNumber; j++) {
                var source = edgeNodeName(j);
                var task = taskService.getOne(new QueryWrapper<Task>().eq("source", source).eq("time_slot", i));
                var nextTask = taskService.getOne(new QueryWrapper<Task>().eq("source", source).eq("time_slot", i + 1));
                var state = transitionService.getState(task.getId());
                states[i - 1][j - 1] = state;
                var action = transitionService.getAction(task.getId());
                actions[i - 1][j - 1] = new int[]{action};
                var nextState = transitionService.getState(nextTask.getId());
                nextStates[i - 1][j - 1] = nextState;
                var availAction = transitionService.getAvailAction(task.getId());
                availActions[i - 1][j - 1] = availAction;
                var reward = transitionService.getReward(task.getId());
                long totalTime = task.getTransmissionWaitingTime() + task.getTransmissionTime() + task.getExecutionWaitingTime() + task.getExecutionTime();
                var endTimeSlot = task.getTimeSlot() + totalTime / timeSlotLen;
                int timeSlotIndex = (int) endTimeSlot - 1;
                // A late-finishing task would otherwise index past the array (ArrayIndexOutOfBoundsException).
                if (timeSlotIndex >= 0 && timeSlotIndex < teamRewards.length) {
                    teamRewards[timeSlotIndex] += reward;
                }
            }
        }
        for (int i = 1; i <= episodeLimit; i++) {
            for (int j = 1; j <= agentNumber; j++) {
                rewards[i - 1][j - 1] = new float[]{teamRewards[i - 1]};
            }
        }
        for (int i = 1; i <= episodeLimit; i++) {
            buffer.insert(states[i - 1], actions[i - 1], availActions[i - 1], rewards[i - 1], nextStates[i - 1]);
        }
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/util/DateTimeUtils.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.util; 2 | 3 | import
java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;

/**
 * Helpers for timestamps used to tag experiment runs.
 */
public class DateTimeUtils {

    // Run flags use UTC+8 (China Standard Time) regardless of the host's default time zone.
    private static final ZoneOffset FLAG_OFFSET = ZoneOffset.ofHours(8);

    // 24-hour clock ("HH"): with "hh", a morning run and an evening run 12 hours apart produce the same flag.
    private static final DateTimeFormatter FLAG_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd-HH_mm_ss");

    /**
     * @return the current UTC+8 date-time formatted as yyyy-MM-dd-HH_mm_ss, e.g. "2023-05-01-17_30_05"
     */
    public static String getFlag() {
        return FLAG_FORMATTER.format(LocalDateTime.now(FLAG_OFFSET));
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/util/MathUtils.java:
--------------------------------------------------------------------------------
package cn.edu.scut.util;

import java.util.ArrayList;

/**
 * Basic descriptive statistics over lists of doubles.
 */
public class MathUtils {
    /**
     * @return the arithmetic mean of data; NaN for an empty list (0 / 0)
     */
    public static double avg(ArrayList<Double> data) {
        double sum = 0;
        for (Double val : data) {
            sum += val;
        }
        return sum / data.size();
    }

    /**
     * Sample standard deviation (Bessel-corrected: divides by n - 1).
     *
     * @return the sample standard deviation of data; NaN for a single value (0 / 0)
     */
    public static double std(ArrayList<Double> data) {
        double res = 0;
        double avg = avg(data);
        for (Double val : data) {
            res += Math.pow((val - avg), 2);
        }
        res /= data.size() - 1;
        res = Math.sqrt(res);
        return res;
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/util/SpringBeanUtils.java:
--------------------------------------------------------------------------------
package cn.edu.scut.util;

import org.springframework.context.ApplicationContext;

/**
 * Static holder for the Spring application context, set by ExperimentApp#main.
 */
public class SpringBeanUtils {
    public static ApplicationContext applicationContext;

    /** Stores the context so that non-Spring code can look beans up. */
    public static void setApplicationContext(ApplicationContext applicationContext) {
        SpringBeanUtils.applicationContext = applicationContext;
    }
}
--------------------------------------------------------------------------------
/edge-experiment/src/main/resources/bootstrap.yaml:
--------------------------------------------------------------------------------
1 | spring: 2 | application: 3 | name: edge-experiment 4 | profiles: 5 |
# active: edge-computing,mappo 6 | # active: edge-computing,masac 7 | # active: edge-computing,reliability-two-choice 8 | # active: edge-computing,reactive 9 | active: edge-computing,random 10 | # active: edge-computing,local 11 | cloud: 12 | nacos: 13 | discovery: 14 | server-addr: 222.201.187.50:30848 15 | watch-delay: 3000 16 | config: 17 | server-addr: 222.201.187.50:30848 18 | prefix: application 19 | file-extension: yaml 20 | server: 21 | port: 9001 -------------------------------------------------------------------------------- /edge-experiment/test-results/buffer/test-buffer.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhc0512/edge-computing/d13e6f16266b5b98d69c8e540f61ac60f6db6e31/edge-experiment/test-results/buffer/test-buffer.txt -------------------------------------------------------------------------------- /edge-node/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM lhc-ubuntu-20:v1.1 2 | LABEL maintainer=hongcai 3 | ENV PATH /root/lhc_dev/jdk-17.0.5/bin:${PATH} 4 | COPY target/edge-node-1.0-SNAPSHOT.jar /root/app.jar 5 | WORKDIR /root 6 | ENTRYPOINT ["sh","-c","java -jar app.jar --spring_application_name=$spring_application_name"] -------------------------------------------------------------------------------- /edge-node/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | edge-computing 7 | cn.edu.scut 8 | 1.0-SNAPSHOT 9 | 10 | 4.0.0 11 | 12 | edge-node 13 | 14 | 15 | 17 16 | 17 17 | 18 | 19 | 20 | 21 | mysql 22 | mysql-connector-java 23 | 24 | 25 | com.alibaba 26 | druid-spring-boot-starter 27 | 28 | 29 | com.baomidou 30 | mybatis-plus-boot-starter 31 | 32 | 33 | 34 | com.alibaba.cloud 35 | spring-cloud-starter-alibaba-nacos-discovery 36 | 37 | 38 | com.alibaba.cloud 39 | spring-cloud-starter-alibaba-nacos-config 40 | 41 | 42 | org.springframework.cloud 43 | 
spring-cloud-starter-bootstrap 44 | 45 | 46 | 47 | 48 | org.springframework.boot 49 | spring-boot-starter-web 50 | 51 | 52 | 53 | org.springframework.boot 54 | spring-boot-starter-logging 55 | 56 | 57 | 58 | 59 | org.projectlombok 60 | lombok 61 | true 62 | 63 | 64 | 65 | 66 | org.springframework.boot 67 | spring-boot-starter-test 68 | test 69 | 70 | 71 | 72 | 73 | cn.edu.scut 74 | edge-api 75 | 1.0-SNAPSHOT 76 | 77 | 78 | org.springframework.cloud 79 | spring-cloud-loadbalancer 80 | 81 | 82 | org.apache.hadoop 83 | hadoop-common 84 | 85 | 86 | org.slf4j 87 | slf4j-api 88 | 89 | 90 | org.slf4j 91 | slf4j-reload4j 92 | 93 | 94 | 95 | 96 | 97 | org.apache.hadoop 98 | hadoop-hdfs 99 | 100 | 101 | 102 | org.apache.hadoop 103 | hadoop-hdfs-client 104 | 105 | 106 | 107 | cn.edu.scut 108 | edge-algorithm 109 | 1.0-SNAPSHOT 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | org.springframework.boot 119 | spring-boot-maven-plugin 120 | 121 | true 122 | true 123 | 124 | 125 | 126 | 127 | repackage 128 | 129 | 130 | 131 | 132 | 133 | 134 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/EdgeNodeApp.java:
--------------------------------------------------------------------------------
package cn.edu.scut;

import cn.edu.scut.utils.SpringBeanUtils;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.cloud.client.discovery.EnableDiscoveryClient;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.scheduling.annotation.EnableAsync;

/**
 * Entry point of an edge node.
 *
 * Enables discovery-client registration and @Async execution, and publishes the application context through
 * SpringBeanUtils.
 */
@SpringBootApplication
@EnableDiscoveryClient
@EnableAsync
public class EdgeNodeApp {
    public static void main(String[] args) {
        SpringBeanUtils.setApplicationContext(boot(args));
    }

    /** Starts the Spring application and returns its context. */
    private static ConfigurableApplicationContext boot(String[] args) {
        return SpringApplication.run(EdgeNodeApp.class, args);
    }
}
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/bean/EdgeNodeSystem.java:
--------------------------------------------------------------------------------
package cn.edu.scut.bean;

import cn.edu.scut.queue.ExecutionQueue;
import cn.edu.scut.queue.TransmissionQueue;
import lombok.Getter;
import lombok.Setter;
import org.springframework.stereotype.Component;

import java.util.Map;

/**
 * Runtime state of this edge node: its descriptor, links, execution queue and transmission queues.
 * Registered as a Spring @Component; Lombok generates the getters and setters.
 */
@Getter
@Setter
@Component
public class EdgeNodeSystem {
    private EdgeNode edgeNode;
    // NOTE(review): the type arguments of both maps are missing in this listing (presumably lost during
    // extraction); restore the Map<K, V> parameters from the repository.
    private Map linkMap;
    private ExecutionQueue executionQueue;
    private Map transmissionQueueMap;
    // threshold applied to the execution queue; its unit and meaning are defined by its users — TODO confirm
    private float executionQueueThreshold;
}
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/config/Config.java:
--------------------------------------------------------------------------------
package cn.edu.scut.config;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.client.loadbalancer.LoadBalanced;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;

import java.util.Random;

/**
 * Per-node beans: seeded random sources and a load-balanced RestTemplate.
 *
 * The node id is the third dash-separated part of spring.application.name, e.g. 3 for "edge-node-3". Each random
 * source is seeded with its configured seed plus that id, so every node gets a different but reproducible stream.
 */
@Configuration
@Slf4j
public class Config {

    @Value("${edgeComputing.reliabilitySeed}")
    Integer reliabilitySeed;

    @Value("${edgeComputing.schedulerSeed}")
    Integer schedulerSeed;

    @Value("${spring.application.name}")
    String name;

    /** Random source seeded with reliabilitySeed + node id. */
    @Bean
    public Random reliabilityRandom() {
        return new Random(reliabilitySeed + nodeId());
    }

    /** Random source seeded with schedulerSeed + node id. */
    @Bean
    public Random schedulerRandom() {
        return new Random(schedulerSeed + nodeId());
    }

    /** RestTemplate whose host names (e.g. http://edge-node-2) are resolved through the load balancer. */
    @Bean
    @LoadBalanced
    RestTemplate restTemplate() {
        return new RestTemplate();
    }

    /**
     * Parses the numeric node id from spring.application.name.
     *
     * @return the id, e.g. 3 for "edge-node-3"
     * @throws IllegalStateException if the name has fewer than three dash-separated parts or the third is not an int
     */
    private int nodeId() {
        var parts = name.split("-");
        if (parts.length < 3) {
            throw new IllegalStateException("spring.application.name must look like edge-node-N, got: " + name);
        }
        try {
            return Integer.parseInt(parts[2]);
        } catch (NumberFormatException e) {
            throw new IllegalStateException("spring.application.name must look like edge-node-N, got: " + name, e);
        }
    }
}
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/config/HadoopConfig.java:
--------------------------------------------------------------------------------
package cn.edu.scut.config;

import org.apache.hadoop.fs.FileSystem;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;

/**
 * Provides the HDFS FileSystem bean of an edge node.
 */
@Configuration
public class HadoopConfig {

    @Value("${hadoop.hdfs.url}")
    private String hdfsUrl;

    // User to act as on HDFS. The default keeps the previously hard-coded account, so existing configs still work.
    @Value("${hadoop.hdfs.user:hongcai}")
    private String hdfsUser;

    /**
     * Connects to the HDFS cluster at hadoop.hdfs.url as hadoop.hdfs.user.
     *
     * @return the HDFS file system handle
     * @throws RuntimeException if the URL is malformed, the connection fails or the call is interrupted
     */
    @Bean
    public FileSystem fileSystem() {
        URI uri;
        try {
            uri = new URI(hdfsUrl);
        } catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
        var configuration = new org.apache.hadoop.conf.Configuration();
        try {
            return FileSystem.get(uri, configuration, hdfsUser);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (InterruptedException e) {
            // keep the interrupt status visible to whoever handles the exception
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }
}
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/controller/CmdController.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.controller; 2 | 3 | import cn.edu.scut.service.EdgeNodeSystemService; 4 | import org.springframework.web.bind.annotation.GetMapping; 5 | import org.springframework.web.bind.annotation.RequestMapping; 6 | import org.springframework.web.bind.annotation.RequestMethod; 7 | import org.springframework.web.bind.annotation.RestController; 8 | 9
| import javax.annotation.Resource; 10 | 11 | @RestController 12 | @RequestMapping(value = "/cmd", method = {RequestMethod.GET, RequestMethod.POST}) 13 | public class CmdController { 14 | 15 | @Resource 16 | EdgeNodeSystemService edgeNodeSystemService; 17 | 18 | @GetMapping("/init") 19 | public String init(){ 20 | edgeNodeSystemService.init(); 21 | return "success"; 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /edge-node/src/main/java/cn/edu/scut/controller/EdgeNodeController.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.controller; 2 | 3 | import cn.edu.scut.bean.EdgeNode; 4 | import cn.edu.scut.bean.RatcVo; 5 | import cn.edu.scut.bean.Task; 6 | import cn.edu.scut.service.EdgeNodeSystemService; 7 | import cn.edu.scut.thread.ExecutionRunnable; 8 | import lombok.extern.apachecommons.CommonsLog; 9 | import org.springframework.web.bind.annotation.*; 10 | 11 | import javax.annotation.Resource; 12 | 13 | @RestController 14 | @RequestMapping(value = "/edgeNode", method = {RequestMethod.GET, RequestMethod.POST}) 15 | @CommonsLog 16 | public class EdgeNodeController { 17 | 18 | @Resource 19 | private EdgeNodeSystemService edgeNodeSystemService; 20 | 21 | @PostMapping("/task") 22 | public String receiveEdgeNodeTask(@RequestBody Task task) { 23 | log.info("receive task from edge node: " + task.getSource()); 24 | edgeNodeSystemService.processEdgeNodeTask(task); 25 | return "success"; 26 | } 27 | 28 | @GetMapping("/queue") 29 | public Integer queue() { 30 | return edgeNodeSystemService.getEdgeNodeSystem().getExecutionQueue().getSize(); 31 | } 32 | 33 | @GetMapping("/info") 34 | public EdgeNode info() { 35 | return edgeNodeSystemService.getEdgeNodeSystem().getEdgeNode(); 36 | } 37 | 38 | @GetMapping("/waitingTime") 39 | public Long waitingTime() { 40 | var queue = edgeNodeSystemService.getEdgeNodeSystem().getExecutionQueue().getExecutor().getQueue(); 41 | long 
waitingTime = 0; 42 | for (Runnable runnable : queue) { 43 | var r = (ExecutionRunnable) runnable; 44 | waitingTime += r.getTask().getExecutionTime(); 45 | } 46 | return waitingTime; 47 | } 48 | 49 | @GetMapping("/ratc") 50 | public RatcVo ratc() { 51 | var queue = edgeNodeSystemService.getEdgeNodeSystem().getExecutionQueue().getExecutor().getQueue(); 52 | long waitingTime = 0; 53 | for (Runnable runnable : queue) { 54 | var r = (ExecutionRunnable) runnable; 55 | waitingTime += r.getTask().getExecutionTime(); 56 | } 57 | var res = new RatcVo(); 58 | res.setWaitingTime(waitingTime); 59 | var edgeNode = edgeNodeSystemService.getEdgeNodeSystem().getEdgeNode(); 60 | res.setExecutionFailureRate(edgeNode.getExecutionFailureRate()); 61 | res.setCapacity(edgeNode.getCapacity()); 62 | res.setEdgeId(edgeNode.getName()); 63 | return res; 64 | } 65 | 66 | @GetMapping("/available") 67 | public Integer avail() { 68 | int queueSize = edgeNodeSystemService.getEdgeNodeSystem().getExecutionQueue().getSize(); 69 | if (queueSize < edgeNodeSystemService.getEdgeNodeSystem().getExecutionQueueThreshold()) { 70 | return 1; 71 | } else { 72 | return 0; 73 | } 74 | } 75 | } -------------------------------------------------------------------------------- /edge-node/src/main/java/cn/edu/scut/controller/FileController.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.controller; 2 | 3 | import lombok.extern.slf4j.Slf4j; 4 | import org.springframework.http.MediaType; 5 | import org.springframework.util.StreamUtils; 6 | import org.springframework.web.bind.annotation.PostMapping; 7 | import org.springframework.web.bind.annotation.RestController; 8 | 9 | import javax.servlet.ServletException; 10 | import javax.servlet.http.HttpServletRequest; 11 | import javax.servlet.http.Part; 12 | import java.io.IOException; 13 | import java.io.InputStream; 14 | import java.nio.file.Files; 15 | import java.nio.file.Paths; 16 | 17 | @RestController 18 | 
@Slf4j 19 | public class FileController { 20 | 21 | @PostMapping(value = "/upload", consumes = MediaType.MULTIPART_FORM_DATA_VALUE, produces = MediaType.TEXT_PLAIN_VALUE) 22 | public String upload(HttpServletRequest request) throws IOException, ServletException { 23 | for (Part part : request.getParts()) { 24 | log.info("content type: {}", part.getContentType()); //text/plain 25 | log.info("parameter name: {}", part.getName()); // actor 26 | log.info("file name: {}", part.getSubmittedFileName()); // actor.param 27 | log.info("file size: {}", part.getSize()); 28 | log.info("save file"); 29 | InputStream inputStream = part.getInputStream(); 30 | var path = Paths.get("results/model/edge-node-1/"+ part.getSubmittedFileName()); 31 | Files.createDirectories(path.getParent()); 32 | Files.write(path, StreamUtils.copyToByteArray(inputStream)); 33 | } 34 | return "success"; 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /edge-node/src/main/java/cn/edu/scut/controller/ModelController.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.controller; 2 | 3 | import cn.edu.scut.agent.IMAAgent; 4 | import lombok.extern.slf4j.Slf4j; 5 | import org.springframework.beans.factory.annotation.Autowired; 6 | import org.springframework.web.bind.annotation.GetMapping; 7 | import org.springframework.web.bind.annotation.PathVariable; 8 | import org.springframework.web.bind.annotation.PostMapping; 9 | import org.springframework.web.bind.annotation.RestController; 10 | 11 | import javax.servlet.ServletException; 12 | import javax.servlet.http.HttpServletRequest; 13 | import javax.servlet.http.Part; 14 | import java.io.IOException; 15 | 16 | @RestController 17 | @Slf4j 18 | public class ModelController { 19 | 20 | // heuristic 21 | @Autowired(required = false) 22 | private IMAAgent agent; 23 | 24 | @GetMapping("/updateParam/{flag}") 25 | public String updateParam(@PathVariable("flag") 
String flag) { 26 | agent.loadHdfsModel(flag); 27 | return "success"; 28 | } 29 | 30 | @PostMapping("/updateModel") 31 | public String updateStreamParam(HttpServletRequest request) throws IOException, ServletException { 32 | for (Part part : request.getParts()) { 33 | log.info("content type: {}", part.getContentType()); //text/plain 34 | log.info("parameter name: {}", part.getName()); // actor 35 | log.info("file name: {}", part.getSubmittedFileName()); // ctor.param 36 | log.info("file size: {}", part.getSize()); 37 | agent.loadSteamModel(part.getInputStream(), part.getSubmittedFileName()); 38 | log.info("update model completed!"); 39 | } 40 | return "success"; 41 | } 42 | 43 | @GetMapping("/loadModel/{flag}") 44 | public String loadLocalModel(@PathVariable("flag") String flag) { 45 | agent.loadModel(flag); 46 | return "success"; 47 | } 48 | } -------------------------------------------------------------------------------- /edge-node/src/main/java/cn/edu/scut/controller/UserController.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.controller; 2 | 3 | import cn.edu.scut.bean.Task; 4 | import cn.edu.scut.service.EdgeNodeSystemService; 5 | import lombok.extern.apachecommons.CommonsLog; 6 | import org.springframework.web.bind.annotation.*; 7 | 8 | import javax.annotation.Resource; 9 | import java.time.LocalDateTime; 10 | 11 | @RestController 12 | @RequestMapping(value = "/user", method = {RequestMethod.GET, RequestMethod.POST}) 13 | @CommonsLog 14 | public class UserController { 15 | @Resource 16 | EdgeNodeSystemService edgeNodeSystemService; 17 | 18 | @PostMapping("/task") 19 | public String receiveUserTask(@RequestBody Task task) { 20 | log.info("receive task from user"); 21 | task.setArrivalTime(LocalDateTime.now()); 22 | edgeNodeSystemService.processUserTask(task); 23 | return "success"; 24 | } 25 | } 26 | -------------------------------------------------------------------------------- 
/edge-node/src/main/java/cn/edu/scut/queue/ExecutionQueue.java:
--------------------------------------------------------------------------------
package cn.edu.scut.queue;

import cn.edu.scut.bean.Task;
import cn.edu.scut.thread.ExecutionRunnable;
import cn.edu.scut.utils.SpringBeanUtils;
import lombok.Getter;

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * FIFO execution queue for this node. It is built on a single-thread pool: one task runs
 * at a time, and the pool's work queue holds the tasks waiting to execute.
 */
public class ExecutionQueue {

    @Getter
    private ThreadPoolExecutor executor = new ThreadPoolExecutor(1, 1, 0, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());

    /** Wraps the task in a fresh (prototype-scoped) ExecutionRunnable and submits it. */
    public void add(Task task) {
        var job = SpringBeanUtils.applicationContext.getBean(ExecutionRunnable.class);
        job.setTask(task);
        executor.execute(job);
    }

    /** @return number of tasks waiting; the task currently running is not counted */
    public int getSize() {
        return executor.getQueue().size();
    }
}
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/queue/TransmissionQueue.java:
--------------------------------------------------------------------------------
package cn.edu.scut.queue;

import cn.edu.scut.bean.Task;
import cn.edu.scut.thread.TransmissionRunnable;
import cn.edu.scut.utils.SpringBeanUtils;

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;


/**
 * FIFO transmission queue for one outgoing link. It is built on a single-thread pool, so
 * tasks sent over the same link are transmitted one after another.
 */
public class TransmissionQueue {
    ThreadPoolExecutor thread = new ThreadPoolExecutor(1, 1, 0, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());

    /** Wraps the task in a fresh (prototype-scoped) TransmissionRunnable and submits it. */
    public void add(Task task) {
        var job = SpringBeanUtils.applicationContext.getBean(TransmissionRunnable.class);
        job.setTask(task);
        thread.execute(job);
    }
}
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/scheduler/DRLScheduler.java:
--------------------------------------------------------------------------------
package cn.edu.scut.scheduler;


import cn.edu.scut.agent.IMAAgent;
import cn.edu.scut.bean.Task;
import cn.edu.scut.service.TaskService;
import cn.edu.scut.service.TransitionService;
import cn.edu.scut.util.ArrayUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.stereotype.Service;

import java.util.Arrays;

/**
 * Scheduler that asks the multi-agent RL policy ({@link IMAAgent}) which edge node
 * should execute a task.
 */
@Service
@Slf4j
@RefreshScope
public class DRLScheduler implements IScheduler {

    // Optional bean (original note: "heuristic method"); it may be absent when a heuristic
    // scheduler is configured instead.
    @Autowired(required = false)
    private IMAAgent agent;

    @Autowired
    private TransitionService transitionService;

    @Autowired
    private TaskService taskService;

    @Value("${edgeComputing.edgeNodeNumber}")
    private int agentNumber;

    @Value("${spring.application.name}")
    public String name;

    /**
     * Stores the action mask on the task, builds the state and lets the agent choose.
     *
     * @param taskId id of a persisted task
     * @return {@code edge-node-<k>}, where {@code k} is the agent's action index plus one
     * @throws IllegalStateException    if no IMAAgent bean is configured
     * @throws IllegalArgumentException if no task with {@code taskId} exists
     */
    @Override
    public String selectAction(Long taskId) {
        if (agent == null) {
            throw new IllegalStateException("DRLScheduler requires an IMAAgent bean, but none is configured");
        }
        Task task = taskService.getById(taskId);
        if (task == null) {
            throw new IllegalArgumentException("task not found: " + taskId);
        }
        // Slots 0..agentNumber-1 (edge-node-1..edge-node-N) are all allowed. The trailing
        // slot (index agentNumber) stays 0, i.e. masked out; its meaning is defined by the
        // agent (TODO confirm).
        var availAction = new int[agentNumber + 1];
        Arrays.fill(availAction, 0, agentNumber, 1);
        task.setAvailAction(ArrayUtils.arrayToString(availAction));
        taskService.updateById(task);
        var state = transitionService.getState(taskId);
        int i = agent.selectAction(state, availAction, false) + 1;
        log.info("drl select action {}", i);
        return String.format("edge-node-%d", i);
    }
}
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/scheduler/IScheduler.java:
--------------------------------------------------------------------------------
package cn.edu.scut.scheduler;

/**
 * Strategy for choosing the edge node that should execute a task.
 */
public interface IScheduler {
    /**
     * @param taskId id of the task to place
     * @return name of the chosen edge node, e.g. {@code edge-node-3}
     */
    String selectAction(Long taskId);
}
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/scheduler/RandomScheduler.java:
--------------------------------------------------------------------------------
package cn.edu.scut.scheduler;

import lombok.Setter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;

import java.util.Random;

/**
 * Baseline scheduler that sends each task to a uniformly random edge node.
 */
@Lazy
@Component
@Setter
public class RandomScheduler implements IScheduler {
    @Autowired
    private Random schedulerRandom;

    @Value("${edgeComputing.edgeNodeNumber}")
    private int edgeNodeNumber;

    /**
     * Picks one of {@code edge-node-1 .. edge-node-N} (N = edgeNodeNumber) uniformly at
     * random. It draws exactly one {@code nextInt(edgeNodeNumber)}, so seeded runs reproduce
     * the previous selection sequence.
     *
     * @throws IllegalArgumentException if edgeNodeNumber is not positive
     */
    @Override
    public String selectAction(Long taskId) {
        int index = schedulerRandom.nextInt(edgeNodeNumber);
        return String.format("edge-node-%d", index + 1);
    }
}
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/scheduler/ReactiveScheduler.java:
--------------------------------------------------------------------------------
package cn.edu.scut.scheduler;

import cn.edu.scut.bean.EdgeNodeSystem;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;

import java.util.*;

/**
 * Reactive two-choice scheduler. The task stays local while the local execution queue is
 * short enough; otherwise two edge nodes are sampled and the one reporting the shorter
 * execution queue is chosen.
 */
@Lazy
@Component
public class ReactiveScheduler implements IScheduler {

    @Value("${spring.application.name}")
    public String name;

    @Autowired
    private EdgeNodeSystem edgeNodeSystem;

    @Autowired
    private RestTemplate restTemplate;

    @Value("${edgeComputing.edgeNodeNumber}")
    private int edgeNodeNumber;

    @Value("${heuristic.queue-coef}")
    private float queueCoef;

    @Autowired
    private Random schedulerRandom;

    @Override
    public String selectAction(Long taskId) {
        // Keep the task locally while the queue is below threshold * heuristic.queue-coef.
        if (edgeNodeSystem.getExecutionQueue().getSize() < edgeNodeSystem.getExecutionQueueThreshold() * queueCoef) {
            return name;
        }
        String best = null;
        int bestQueue = Integer.MAX_VALUE;
        for (String edgeNodeId : sampleCandidates()) {
            var url = String.format("http://%s/edgeNode/queue", edgeNodeId);
            Integer queueSize = restTemplate.getForObject(url, Integer.class);
            // An empty response counts as "maximally busy"; previously it caused an NPE in
            // the comparator.
            int size = queueSize == null ? Integer.MAX_VALUE : queueSize;
            // Strict '<': on ties the first sampled node wins, as with the previous priority queue.
            if (best == null || size < bestQueue) {
                best = edgeNodeId;
                bestQueue = size;
            }
        }
        return best;
    }

    // Samples up to two distinct node names at random (bounded at 1000 draws). This node
    // itself is skipped when more than two nodes exist.
    private Set<String> sampleCandidates() {
        var edgeNodeIds = new ArrayList<String>();
        for (int i = 1; i <= edgeNodeNumber; i++) {
            edgeNodeIds.add(String.format("edge-node-%d", i));
        }
        var selectedNodes = new HashSet<String>();
        for (int i = 0; i < 1000; i++) {
            var edgeNodeId = edgeNodeIds.get(schedulerRandom.nextInt(edgeNodeIds.size()));
            if (edgeNodeIds.size() > 2 && edgeNodeId.equals(name)) {
                continue;
            }
            selectedNodes.add(edgeNodeId);
            if (selectedNodes.size() == 2) {
                break;
            }
        }
        return selectedNodes;
    }
}
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/scheduler/ReliabilityTwoChoice.java:
--------------------------------------------------------------------------------
package cn.edu.scut.scheduler;

import cn.edu.scut.bean.RatcVo;
import cn.edu.scut.service.TaskService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

/**
 * Reliability-aware two-choice scheduler. It samples two edge nodes. If both can finish
 * the task before its deadline, it picks the one with the lower execution failure rate;
 * otherwise it picks the one with the shorter estimated completion time.
 */
@Slf4j
@Lazy
@Component
public class ReliabilityTwoChoice implements IScheduler {

    @Autowired
    private Random schedulerRandom;

    @Autowired
    private RestTemplate restTemplate;

    @Autowired
    private TaskService taskService;

    @Value("${edgeComputing.edgeNodeNumber}")
    private int edgeNodeNumber;

    @Override
    public String selectAction(Long taskId) {
        var task = taskService.getById(taskId);

        var candidates = new ArrayList<RatcVo>();
        for (String edgeNodeId : sampleTwoNodes()) {
            var url1 = String.format("http://%s/edgeNode/ratc", edgeNodeId);
            candidates.add(restTemplate.getForObject(url1, RatcVo.class));
        }

        // Estimated completion time (ms) = this task's execution time on the node plus the
        // node's queued waiting time.
        boolean allMeetDeadline = true;
        for (RatcVo ratcVo : candidates) {
            long executionTime = (long) (task.getCpuCycle().doubleValue() / ratcVo.getCapacity().doubleValue() * 1000);
            ratcVo.setTotalTime(executionTime + ratcVo.getWaitingTime());
            if (ratcVo.getTotalTime() > task.getDeadline()) {
                allMeetDeadline = false;
            }
        }

        // With a single edge node only one distinct candidate can be sampled; previously
        // get(1) threw IndexOutOfBoundsException here.
        if (candidates.size() == 1) {
            return candidates.get(0).getEdgeId();
        }
        var first = candidates.get(0);
        var second = candidates.get(1);
        if (allMeetDeadline) {
            return first.getExecutionFailureRate() < second.getExecutionFailureRate() ? first.getEdgeId() : second.getEdgeId();
        }
        return first.getTotalTime() < second.getTotalTime() ? first.getEdgeId() : second.getEdgeId();
    }

    // Draws node names at random until two distinct ones are found (at most 1000 draws).
    private Set<String> sampleTwoNodes() {
        var edgeNodeIds = new ArrayList<String>();
        for (int i = 1; i <= edgeNodeNumber; i++) {
            edgeNodeIds.add(String.format("edge-node-%d", i));
        }
        var selectedNodes = new HashSet<String>();
        for (int i = 0; i < 1000; i++) {
            selectedNodes.add(edgeNodeIds.get(schedulerRandom.nextInt(edgeNodeIds.size())));
            if (selectedNodes.size() == 2) {
                break;
            }
        }
        return selectedNodes;
    }
}
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/service/EdgeNodeSystemService.java:
--------------------------------------------------------------------------------
package cn.edu.scut.service;

import cn.edu.scut.bean.*;
import cn.edu.scut.queue.ExecutionQueue;
import cn.edu.scut.queue.TransmissionQueue;
import cn.edu.scut.scheduler.*;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.context.annotation.Lazy;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import java.time.LocalDateTime;
import java.util.HashMap;

/**
 * Core of an edge node. It routes user tasks through the configured scheduler, enqueues
 * tasks for local execution or for transmission to another node, and loads this node's
 * configuration.
 */
@Setter
@Getter
@Service
@Slf4j
@RefreshScope
public class EdgeNodeSystemService {

    @Value("${spring.application.name}")
    private String name;

    @Value("${edgeComputing.seed}")
    private Integer seed;

    @Value("${edgeComputing.cpuCapacity}")
    private Integer cpuCapacity;

    // Refreshable (class is @RefreshScope): selects which scheduler processUserTask uses.
    @Value("${edgeComputing.scheduler}")
    private String scheduler;

    @Value("${edgeComputing.minCpuCore}")
    private int minCpuCore;

    @Autowired
    private RestTemplate restTemplate;

    @Autowired
    private EdgeNodeSystem edgeNodeSystem;

    @Lazy
    @Autowired
    private DRLScheduler DRLScheduler;

    @Lazy
    @Autowired
    private RandomScheduler randomScheduler;

    @Lazy
    @Autowired
    private ReliabilityTwoChoice reliabilityTwoChoice;

    @Lazy
    @Autowired
    private ReactiveScheduler reactiveScheduler;

    @Autowired
    private EdgeNodeService edgeNodeService;

    @Autowired
    private LinkService linkService;

    @Value("${edgeComputing.queueCoef}")
    private float queueCoef;

    /**
     * Chooses a destination for a user task with the scheduler named by
     * {@code edgeComputing.scheduler}. The task is then either executed locally or queued
     * for transmission on the link to the chosen node. Runs asynchronously ({@code @Async}).
     *
     * @throws RuntimeException      if the configured scheduler name is unknown
     * @throws IllegalStateException if {@link #init()} loaded no link to the chosen node
     */
    @Async
    public void processUserTask(Task task) {
        String service = switch (scheduler) {
            case "rl" -> DRLScheduler.selectAction(task.getId());
            case "random" -> randomScheduler.selectAction(task.getId());
            case "reactive" -> reactiveScheduler.selectAction(task.getId());
            case "reliability-two-choice" -> reliabilityTwoChoice.selectAction(task.getId());
            default -> throw new RuntimeException("no scheduler");
        };
        if (service.equals(name)) {
            // Local execution: no transmission, so transmission "ends" right now.
            var now = LocalDateTime.now();
            task.setBeginExecutionTime(now);
            task.setEndTransmissionTime(now);
            task.setTransmissionTime(0L);
            task.setTransmissionWaitingTime(0L);
            task.setDestination(name);
            processEdgeNodeTask(task);
            return;
        }
        Link link = edgeNodeSystem.getLinkMap().get(service);
        if (link == null) {
            throw new IllegalStateException("no link from " + name + " to " + service);
        }
        // Transmission time in ms (TransmissionRunnable sleeps for this long).
        Double transmissionTime = task.getTaskSize() / link.getTransmissionRate() * 1000;
        task.setTransmissionTime(transmissionTime.longValue());
        task.setDestination(service);
        edgeNodeSystem.getTransmissionQueueMap().get(service).add(task);
    }

    /**
     * Computes the task's execution time (ms) on this node's capacity and appends the task
     * to the execution queue.
     */
    public void processEdgeNodeTask(Task task) {
        Double executionTime = task.getCpuCycle().doubleValue() / edgeNodeSystem.getEdgeNode().getCapacity().doubleValue() * 1000;
        task.setExecutionTime(executionTime.longValue());
        edgeNodeSystem.getExecutionQueue().add(task);
    }

    /**
     * Loads this node's row and its outgoing links from the database. It then builds the
     * execution queue, one transmission queue per destination, and the queue threshold.
     *
     * @throws IllegalStateException if no edge-node row matches this node's name
     */
    public void init() {
        var edgeNodeConfig = edgeNodeService.getOne(new QueryWrapper<EdgeNode>().eq("name", name));
        if (edgeNodeConfig == null) {
            throw new IllegalStateException("no edge node configuration found for " + name);
        }
        // Node names are expected to look like "edge-node-<id>".
        var id = Integer.parseInt(name.split("-")[2]);
        EdgeNode edgeNode = new EdgeNode();
        edgeNode.setId(id);
        edgeNode.setName(name);
        edgeNode.setCpuNum(edgeNodeConfig.getCpuNum());
        edgeNode.setExecutionFailureRate(edgeNodeConfig.getExecutionFailureRate());
        edgeNode.setTaskRate(edgeNodeConfig.getTaskRate());
        edgeNode.setCapacity(edgeNodeConfig.getCpuNum() * Constants.Giga.value * cpuCapacity);
        edgeNodeSystem.setEdgeNode(edgeNode);
        edgeNodeSystem.setExecutionQueue(new ExecutionQueue());

        // One transmission queue and one link entry per outgoing destination.
        var transmissionQueueMap = new HashMap<String, TransmissionQueue>();
        var linkMap = new HashMap<String, Link>();
        var links = linkService.list(new QueryWrapper<Link>().eq("source", name));
        for (Link link : links) {
            transmissionQueueMap.put(link.getDestination(), new TransmissionQueue());
            linkMap.put(link.getDestination(), link);
        }
        edgeNodeSystem.setTransmissionQueueMap(transmissionQueueMap);
        edgeNodeSystem.setLinkMap(linkMap);
        // Queue-length threshold, scaled by CPU count relative to minCpuCore; consulted by
        // /edgeNode/available and ReactiveScheduler.
        edgeNodeSystem.setExecutionQueueThreshold(Float.valueOf(edgeNode.getCpuNum()) / (float) minCpuCore * queueCoef);
        log.info("load edge nodes and links configuration completed");
        log.info("{} edge config: {}", name, edgeNode);
        log.info("{} link config:{}", name, links);
    }
}
--------------------------------------------------------------------------------
-------------------------------------------------------------------------------- /edge-node/src/main/java/cn/edu/scut/thread/ExecutionRunnable.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.thread; 2 | 3 | import cn.edu.scut.bean.EdgeNodeSystem; 4 | import cn.edu.scut.bean.Task; 5 | import cn.edu.scut.bean.TaskStatus; 6 | import cn.edu.scut.service.TaskService; 7 | import lombok.Getter; 8 | import lombok.Setter; 9 | import org.springframework.beans.factory.annotation.Autowired; 10 | import org.springframework.cloud.context.config.annotation.RefreshScope; 11 | import org.springframework.context.annotation.Scope; 12 | import org.springframework.stereotype.Component; 13 | 14 | import java.time.Duration; 15 | import java.time.LocalDateTime; 16 | import java.util.Random; 17 | 18 | @Component 19 | @Setter 20 | @Scope("prototype") 21 | @RefreshScope 22 | public class ExecutionRunnable implements Runnable { 23 | 24 | // prototype 25 | @Getter 26 | private Task task; 27 | @Autowired 28 | private TaskService taskService; 29 | 30 | @Autowired 31 | private EdgeNodeSystem edgeNodeSystem; 32 | 33 | @Autowired 34 | private Random reliabilityRandom; 35 | 36 | @Override 37 | public void run() { 38 | task.setBeginExecutionTime(LocalDateTime.now()); 39 | task.setExecutionWaitingTime(Duration.between(task.getEndTransmissionTime(), task.getBeginExecutionTime()).toMillis()); 40 | // drop the task without wasting the resource. 
41 | long estimatedTotalTime = task.getTransmissionWaitingTime() + task.getTransmissionTime() + task.getExecutionWaitingTime() + task.getExecutionTime(); 42 | if (estimatedTotalTime > task.getDeadline()) { 43 | task.setStatus(TaskStatus.DROP); 44 | taskService.updateById(task); 45 | return; 46 | } 47 | try { 48 | Thread.sleep(task.getExecutionTime()); 49 | 50 | } catch (InterruptedException e) { 51 | e.printStackTrace(); 52 | } 53 | task.setEndExecutionTime(LocalDateTime.now()); 54 | // reliability 55 | double reliability = Math.exp(-task.getExecutionTime() / 1000.0 * edgeNodeSystem.getEdgeNode().getExecutionFailureRate()); 56 | if (reliabilityRandom.nextDouble() > reliability) { 57 | task.setStatus(TaskStatus.EXECUTION_FAILURE); 58 | } 59 | else { 60 | task.setStatus(TaskStatus.SUCCESS); 61 | } 62 | taskService.updateById(task); 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /edge-node/src/main/java/cn/edu/scut/thread/TransmissionRunnable.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.thread; 2 | 3 | import cn.edu.scut.bean.EdgeNodeSystem; 4 | import cn.edu.scut.bean.Link; 5 | import cn.edu.scut.bean.Task; 6 | import cn.edu.scut.bean.TaskStatus; 7 | import cn.edu.scut.service.TaskService; 8 | import lombok.Setter; 9 | import lombok.extern.apachecommons.CommonsLog; 10 | import org.springframework.beans.factory.annotation.Autowired; 11 | import org.springframework.cloud.context.config.annotation.RefreshScope; 12 | import org.springframework.context.annotation.Scope; 13 | import org.springframework.stereotype.Component; 14 | import org.springframework.web.client.RestTemplate; 15 | 16 | import java.time.Duration; 17 | import java.time.LocalDateTime; 18 | import java.util.Random; 19 | 20 | @Component 21 | @Scope("prototype") 22 | @Setter 23 | @RefreshScope 24 | @CommonsLog 25 | public class TransmissionRunnable implements Runnable { 26 | // prototype 27 
| private Task task; 28 | 29 | @Autowired 30 | private EdgeNodeSystem edgeNodeSystem; 31 | 32 | @Autowired 33 | private Random reliabilityRandom; 34 | 35 | @Autowired 36 | private RestTemplate restTemplate; 37 | 38 | @Autowired 39 | private TaskService taskService; 40 | 41 | @Override 42 | public void run() { 43 | task.setBeginTransmissionTime(LocalDateTime.now()); 44 | task.setTransmissionWaitingTime(Duration.between(task.getArrivalTime(), task.getBeginTransmissionTime()).toMillis()); 45 | try { 46 | Thread.sleep(task.getTransmissionTime()); 47 | } catch (InterruptedException e) { 48 | e.printStackTrace(); 49 | } 50 | task.setEndTransmissionTime(LocalDateTime.now()); 51 | // reliability 52 | String destination = task.getDestination(); 53 | Link link = edgeNodeSystem.getLinkMap().get(destination); 54 | Double transmissionFailureRate = link.getTransmissionFailureRate(); 55 | double reliability = Math.exp(-task.getTransmissionTime() / 1000.0 * transmissionFailureRate); 56 | if (reliabilityRandom.nextDouble() > reliability) { 57 | task.setStatus(TaskStatus.TRANSMISSION_FAILURE); 58 | // fixed null type exception 59 | task.setExecutionWaitingTime(0L); 60 | task.setExecutionTime(0L); 61 | taskService.updateById(task); 62 | } else { 63 | // seed task to other edge node 64 | String url = String.format("http://%s/edgeNode/task", destination); 65 | restTemplate.postForObject(url, task, String.class); 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /edge-node/src/main/java/cn/edu/scut/utils/SpringBeanUtils.java: -------------------------------------------------------------------------------- 1 | package cn.edu.scut.utils; 2 | 3 | import org.springframework.context.ApplicationContext; 4 | 5 | public class SpringBeanUtils { 6 | public static ApplicationContext applicationContext; 7 | 8 | public static void setApplicationContext(ApplicationContext applicationContext) { 9 | SpringBeanUtils.applicationContext = 
applicationContext; 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /edge-node/src/main/resources/META-INF/MANIFEST.MF: -------------------------------------------------------------------------------- 1 | Manifest-Version: 1.0 2 | Main-Class: cn.edu.scut.EdgeNodeApp 3 | 4 | -------------------------------------------------------------------------------- /edge-node/src/main/resources/bootstrap.yaml: -------------------------------------------------------------------------------- 1 | spring: 2 | application: 3 | name: ${spring_application_name:edge-node-1} 4 | profiles: 5 | # active: edge-computing,reactive 6 | # active: edge-computing,reliability-two-choice 7 | # active: edge-computing,mappo 8 | # active: edge-computing,masac 9 | active: edge-computing,random 10 | cloud: 11 | nacos: 12 | discovery: 13 | server-addr: 222.201.187.50:30848 14 | watch-delay: 3000 15 | config: 16 | server-addr: 222.201.187.50:30848 17 | prefix: application 18 | file-extension: yaml -------------------------------------------------------------------------------- /k8s/edge/edge-controller.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Pod 3 | metadata: 4 | namespace: edge-computing 5 | name: edge-controller 6 | labels: 7 | app: edge-controller 8 | spec: 9 | containers: 10 | - name: edge-controller 11 | image: edge-controller:v1.0 12 | imagePullPolicy: IfNotPresent 13 | restartPolicy: Always -------------------------------------------------------------------------------- /k8s/edge/edge-experiment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Pod 3 | metadata: 4 | namespace: edge-computing 5 | name: edge-experiment 6 | labels: 7 | app: edge-experiment 8 | spec: 9 | containers: 10 | - name: edge-experiment 11 | image: edge-experiment:v1.0 12 | imagePullPolicy: IfNotPresent 13 | restartPolicy: Always 
-------------------------------------------------------------------------------- /k8s/edge/edge-node.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Pod 3 | metadata: 4 | namespace: edge-computing 5 | name: $spring_application_name 6 | labels: 7 | app: $spring_application_name 8 | spec: 9 | containers: 10 | - name: $spring_application_name 11 | image: edge-node:v1.0 12 | imagePullPolicy: IfNotPresent 13 | env: 14 | - name: spring_application_name 15 | value: $spring_application_name 16 | restartPolicy: Always -------------------------------------------------------------------------------- /k8s/mysql.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Pod 3 | metadata: 4 | labels: 5 | run: mysql 6 | name: mysql 7 | spec: 8 | containers: 9 | - env: 10 | - name: MYSQL_ROOT_PASSWORD 11 | value: "123456" 12 | image: mysql 13 | name: mysql 14 | volumeMounts: 15 | - mountPath: /var/lib/mysql 16 | name: data 17 | - mountPath: /etc/mysql/conf.d 18 | name: conf 19 | resources: { } 20 | volumes: 21 | - name: data 22 | nfs: 23 | path: /home/hongcai/lhc_data/mysql/data 24 | server: 222.201.187.50 25 | - name: conf 26 | nfs: 27 | path: /home/hongcai/lhc_data/mysql/conf 28 | server: 222.201.187.50 29 | dnsPolicy: ClusterFirst 30 | restartPolicy: Always 31 | 32 | --- 33 | apiVersion: v1 34 | kind: Service 35 | metadata: 36 | labels: 37 | run: mysql 38 | name: mysql 39 | spec: 40 | ports: 41 | - port: 3306 42 | protocol: TCP 43 | nodePort: 30306 44 | targetPort: 3306 45 | selector: 46 | run: mysql 47 | type: NodePort -------------------------------------------------------------------------------- /k8s/nacos.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Pod 3 | metadata: 4 | labels: 5 | run: nacos 6 | name: nacos 7 | spec: 8 | containers: 9 | - env: 10 | - name: MODE 11 | value: 
standalone 12 | image: nacos/nacos-server:v2.1.2 13 | name: nacos 14 | volumeMounts: 15 | - mountPath: /home/nacos/data 16 | name: data 17 | resources: { } 18 | volumes: 19 | - name: data 20 | nfs: 21 | path: /home/hongcai/lhc_data/nacos/data 22 | server: 222.201.187.50 23 | dnsPolicy: ClusterFirst 24 | restartPolicy: Always 25 | --- 26 | apiVersion: v1 27 | kind: Service 28 | metadata: 29 | labels: 30 | run: nacos 31 | name: nacos 32 | spec: 33 | ports: 34 | - port: 8848 35 | name: port1 36 | protocol: TCP 37 | nodePort: 30848 38 | targetPort: 8848 39 | - port: 9848 40 | name: port2 41 | protocol: TCP 42 | nodePort: 31848 43 | targetPort: 9848 44 | selector: 45 | run: nacos 46 | type: NodePort -------------------------------------------------------------------------------- /nacos/DEFAULT_GROUP/application-edge-computing.yaml: -------------------------------------------------------------------------------- 1 | edgeComputing: 2 | seed: 100 3 | edgeNodeSeed: 100 4 | taskSeed: 100 5 | reliabilitySeed: 100 6 | schedulerSeed: 100 7 | minTaskRate: 0 8 | maxTaskRate: 1.0 9 | cpuCapacity: 3 10 | minCpuCore: 4 11 | maxCpuCore: 32 12 | minTaskSize: 500 13 | maxTaskSize: 2500 14 | minTaskComplexity: 400 15 | maxTaskComplexity: 800 16 | deadline: 1000 17 | minTransmissionRate: 10 18 | maxTransmissionRate: 40 19 | minTransmissionFailureRate: 0 20 | maxTransmissionFailureRate: 0.03 21 | minExecutionFailureRate: 0 22 | maxExecutionFailureRate: 0.4 23 | timeSlot: 100 24 | queueCoef: 3.0 -------------------------------------------------------------------------------- /nacos/DEFAULT_GROUP/application-mappo.yaml: -------------------------------------------------------------------------------- 1 | edgeComputing: 2 | scheduler: rl 3 | flag: flag 4 | 5 | edgeNodeNumber: 10 6 | episodeLimit: 50 7 | episodeNumber: 50 8 | 9 | testNumber: 5 10 | testFrequency: 10 11 | 12 | taskSeed: 102 13 | reliabilitySeed: 102 14 | schedulerSeed: 102 15 | # runner: rl-test 16 | runner: rl-online 17 | rl: 
18 | name: mappo 19 | learning-rate: 0.0003 20 | use-learning-rate-decay: false 21 | start-learning-rate: 0.0003 22 | end-learning-rate: 0.0 23 | use-clip-grad: true 24 | clip-grad-coef: 10 25 | use-normalized-reward: true 26 | gamma: 0.99 27 | use-gae: true 28 | gae-lambda: 0.95 29 | clip: 0.2 30 | use-entropy: false 31 | entropy-coef: 0.01 32 | hidden-shape: 64 33 | epoch: 4 34 | 35 | use-trained-model: false 36 | model-flag: model_flag 37 | 38 | buffer-size: 50 39 | batch-size: 50 40 | action-shape: 11 41 | state-shape: 52 42 | -------------------------------------------------------------------------------- /nacos/DEFAULT_GROUP/application-masac.yaml: -------------------------------------------------------------------------------- 1 | edgeComputing: 2 | scheduler: rl 3 | 4 | flag: flag 5 | 6 | edgeNodeNumber: 10 7 | episodeLimit: 50 8 | episodeNumber: 200 9 | 10 | testNumber: 5 11 | testFrequency: 10 12 | 13 | taskSeed: 102 14 | reliabilitySeed: 102 15 | schedulerSeed: 102 16 | runner: rl-test 17 | # runner: rl-offline 18 | rl: 19 | name: masac 20 | learning-rate: 0.0003 21 | hidden-shape: 64 22 | use-normalized-reward: true 23 | 24 | training-time: 400 25 | test-frequency: 100 26 | 27 | buffer-path: buffer_path 28 | use-cql: true 29 | cql-weight: 0.1 30 | use-soft-update: true 31 | tau: 0.005 32 | use-adaptive-alpha: true 33 | alpha: 0.2 34 | target-entropy-coef: 0.95 35 | # alpha: 0.05 36 | gamma: 0.99 37 | 38 | buffer-size: 10000 39 | batch-size: 128 40 | action-shape: 11 41 | state-shape: 52 42 | 43 | # offline-to-online 44 | use-addition-critic: true 45 | 46 | use-trained-model: true 47 | model-flag: model_flag -------------------------------------------------------------------------------- /nacos/DEFAULT_GROUP/application-random.yaml: -------------------------------------------------------------------------------- 1 | edgeComputing: 2 | scheduler: random 3 | flag: flag 4 | 5 | edgeNodeNumber: 10 6 | episodeLimit: 50 7 | episodeNumber: 200 8 | 9 | 
testNumber: 5 10 | testFrequency: 10 11 | 12 | # runner: heuristic-data 13 | runner: heuristic 14 | # runner: rl-offline 15 | # runner: rl-test 16 | 17 | taskSeed: 102 18 | reliabilitySeed: 102 19 | schedulerSeed: 102 20 | 21 | heuristic: 22 | name: random 23 | -------------------------------------------------------------------------------- /nacos/DEFAULT_GROUP/application-reactive.yaml: -------------------------------------------------------------------------------- 1 | edgeComputing: 2 | scheduler: reactive 3 | 4 | testNumber: 5 5 | testFrequency: 10 6 | 7 | flag: flag 8 | 9 | edgeNodeNumber: 10 10 | episodeLimit: 50 11 | episodeNumber: 5 12 | 13 | taskSeed: 102 14 | reliabilitySeed: 102 15 | schedulerSeed: 102 16 | runner: heuristic 17 | 18 | heuristic: 19 | name: reactive 20 | queue-coef: 1.0 21 | -------------------------------------------------------------------------------- /nacos/DEFAULT_GROUP/application-reliability-two-choice.yaml: -------------------------------------------------------------------------------- 1 | edgeComputing: 2 | scheduler: reliability-two-choice 3 | flag: flag 4 | edgeNodeNumber: 10 5 | episodeLimit: 50 6 | episodeNumber: 200 7 | 8 | testNumber: 5 9 | testFrequency: 10 10 | 11 | taskSeed: 102 12 | reliabilitySeed: 102 13 | schedulerSeed: 102 14 | runner: heuristic-data 15 | # runner: heuristic 16 | 17 | 18 | heuristic: 19 | name: reliability-two-choice 20 | 21 | rl: 22 | buffer-size: 10000 23 | batch-size: 64 24 | action-shape: 11 25 | state-shape: 52 -------------------------------------------------------------------------------- /nacos/DEFAULT_GROUP/application.yaml: -------------------------------------------------------------------------------- 1 | spring: 2 | datasource: 3 | type: com.alibaba.druid.pool.DruidDataSource 4 | driver-class-name: com.mysql.cj.jdbc.Driver 5 | url: jdbc:mysql://xxx.xxx.xxx.xxx:3306/edge_computing?useUnicode=true&characterEncoding=utf-8&useSSL=false&allowPublicKeyRetrieval=true&serverTimezone=GMT 6 | 
username: root 7 | password: 123456 8 | 9 | mybatis-plus: 10 | global-config: 11 | db-config: 12 | id-type: auto 13 | table-prefix: ec_ 14 | hadoop: 15 | hdfs: 16 | url: hdfs://xxx.xxx.xxx.xxx:9000 -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | cn.edu.scut 8 | edge-computing 9 | pom 10 | 1.0-SNAPSHOT 11 | 12 | edge-node 13 | edge-controller 14 | edge-api 15 | edge-experiment 16 | edge-algorithm 17 | 18 | 19 | 20 | 17 21 | 17 22 | 23 | 24 | 25 | 26 | djl.ai 27 | https://oss.sonatype.org/content/repositories/snapshots/ 28 | 29 | 30 | 31 | 32 | 33 | 34 | org.springframework.boot 35 | spring-boot-dependencies 36 | 2.7.5 37 | pom 38 | import 39 | 40 | 41 | 42 | org.springframework.cloud 43 | spring-cloud-dependencies 44 | 2021.0.4 45 | pom 46 | import 47 | 48 | 49 | com.alibaba.cloud 50 | spring-cloud-alibaba-dependencies 51 | 2021.0.4.0 52 | pom 53 | import 54 | 55 | 56 | 57 | 58 | ai.djl 59 | bom 60 | 0.19.0 61 | pom 62 | import 63 | 64 | 65 | 66 | 67 | com.alibaba 68 | druid-spring-boot-starter 69 | 1.2.15 70 | 71 | 72 | com.baomidou 73 | mybatis-plus-boot-starter 74 | 3.5.2 75 | 76 | 77 | mysql 78 | mysql-connector-java 79 | 8.0.31 80 | 81 | 82 | org.apache.commons 83 | commons-math3 84 | 3.6.1 85 | 86 | 87 | org.apache.hadoop 88 | hadoop-common 89 | 3.3.4 90 | 91 | 92 | org.apache.hadoop 93 | hadoop-hdfs 94 | 3.3.4 95 | 96 | 97 | org.apache.hadoop 98 | hadoop-hdfs-client 99 | 3.3.4 100 | 101 | 102 | 103 | org.springframework.cloud 104 | spring-cloud-starter-bootstrap 105 | 3.1.5 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /sql/create_table.sql: -------------------------------------------------------------------------------- 1 | create 2 | database edge_computing; 3 | use 4 | edge_computing; 5 | 6 | CREATE TABLE `ec_test_task` 7 | ( 8 | `id` bigint NOT 
NULL AUTO_INCREMENT, 9 | `source` varchar(20) DEFAULT NULL, 10 | `destination` varchar(20) DEFAULT NULL, 11 | `status` varchar(20) DEFAULT NULL, 12 | `task_size` bigint DEFAULT NULL, 13 | `task_complexity` bigint DEFAULT NULL, 14 | `cpu_cycle` bigint DEFAULT NULL, 15 | `deadline` bigint DEFAULT NULL, 16 | `transmission_waiting_time` bigint DEFAULT NULL, 17 | `transmission_time` bigint DEFAULT NULL, 18 | `execution_waiting_time` bigint DEFAULT NULL, 19 | `execution_time` bigint DEFAULT NULL, 20 | PRIMARY KEY (`id`) 21 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci 22 | DEFAULT CHARSET = utf8mb4 23 | COLLATE = utf8mb4_0900_ai_ci; 24 | 25 | CREATE TABLE `ec_edge_node` 26 | ( 27 | `id` int NOT NULL AUTO_INCREMENT, 28 | `cpu_num` int DEFAULT NULL, 29 | `execution_failure_rate` double DEFAULT NULL, 30 | `task_rate` double DEFAULT NULL, 31 | `name` varchar(20) DEFAULT NULL, 32 | PRIMARY KEY (`id`) 33 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; 34 | 35 | 36 | CREATE TABLE `ec_link` 37 | ( 38 | `id` int NOT NULL AUTO_INCREMENT, 39 | `source` varchar(20) NOT NULL, 40 | `destination` varchar(20) NOT NULL, 41 | `transmission_rate` double NOT NULL, 42 | `transmission_failure_rate` double NOT NULL, 43 | PRIMARY KEY (`id`) 44 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; -------------------------------------------------------------------------------- /start-experiment.sh: -------------------------------------------------------------------------------- 1 | # maven 2 | mvn clean install -Dmaven.test.skip=true 3 | 4 | # docker 5 | ## edge-node 6 | cd ./edge-node || exit 7 | docker rmi edge-node:v1.0 8 | docker build -t edge-node:v1.0 . 9 | ## edge-controller 10 | cd ../edge-controller || exit 11 | docker rmi edge-controller:v1.0 12 | docker build -t edge-controller:v1.0 . 
13 | ## edge-experiment 14 | cd ../edge-experiment || exit 15 | docker rmi edge-experiment:v1.0 16 | docker build -t edge-experiment:v1.0 . 17 | 18 | # kubernetes 19 | ## edge-node 20 | cd ../k8s/edge || exit 21 | for i in {1..10}; do 22 | export spring_application_name=edge-node-$i 23 | envsubst