├── .gitignore
├── edge-algorithm
├── README.md
├── pom.xml
└── src
│ └── main
│ └── java
│ └── cn
│ └── edu
│ └── scut
│ ├── agent
│ ├── IMAAgent.java
│ ├── MABuffer.java
│ ├── mappo
│ │ ├── MAPPOAgent.java
│ │ └── MAPPOConfig.java
│ └── masac
│ │ ├── AlphaBlock.java
│ │ ├── MASACAgent.java
│ │ └── MASACConfiguration.java
│ ├── config
│ └── DJLConfig.java
│ ├── service
│ └── TransitionService.java
│ └── util
│ ├── DJLUtils.java
│ ├── FileUtils.java
│ └── PlotUtils.java
├── edge-api
├── README.md
├── pom.xml
└── src
│ └── main
│ └── java
│ └── cn
│ └── edu
│ └── scut
│ ├── bean
│ ├── Constants.java
│ ├── EdgeNode.java
│ ├── Link.java
│ ├── RatcVo.java
│ ├── StoreConstants.java
│ ├── Task.java
│ └── TaskStatus.java
│ ├── mapper
│ ├── EdgeNodeMapper.java
│ ├── LinkMapper.java
│ └── TaskMapper.java
│ ├── service
│ ├── EdgeNodeService.java
│ ├── LinkService.java
│ └── TaskService.java
│ └── util
│ └── ArrayUtils.java
├── edge-controller
├── Dockerfile
├── README.md
├── pom.xml
└── src
│ └── main
│ ├── java
│ └── cn
│ │ └── edu
│ │ └── scut
│ │ ├── ControllerApp.java
│ │ ├── config
│ │ ├── DataGenerationConfig.java
│ │ └── SpringConfig.java
│ │ ├── controller
│ │ └── EdgeNodeController.java
│ │ ├── service
│ │ ├── EdgeConfigService.java
│ │ └── UserService.java
│ │ ├── thread
│ │ └── UserRunnable.java
│ │ └── utils
│ │ └── SpringBeanUtils.java
│ └── resources
│ ├── META-INF
│ └── MANIFEST.MF
│ └── bootstrap.yaml
├── edge-experiment
├── Dockerfile
├── README.md
├── pom.xml
├── src
│ └── main
│ │ ├── java
│ │ └── cn
│ │ │ └── edu
│ │ │ └── scut
│ │ │ ├── ExperimentApp.java
│ │ │ ├── config
│ │ │ ├── HadoopConfig.java
│ │ │ └── SpringConfig.java
│ │ │ ├── controller
│ │ │ └── ModelController.java
│ │ │ ├── runner
│ │ │ ├── IRunner.java
│ │ │ ├── OfflineMARLTrainingRunner.java
│ │ │ ├── OnlineHeuristicDataRunner.java
│ │ │ ├── OnlineHeuristicTestRunner.java
│ │ │ ├── OnlineMARLTestingRunner.java
│ │ │ └── OnlineMARLTrainingRunner.java
│ │ │ ├── service
│ │ │ ├── PlotService.java
│ │ │ └── RunnerService.java
│ │ │ └── util
│ │ │ ├── DateTimeUtils.java
│ │ │ ├── MathUtils.java
│ │ │ └── SpringBeanUtils.java
│ │ └── resources
│ │ └── bootstrap.yaml
└── test-results
│ └── buffer
│ └── test-buffer.txt
├── edge-node
├── Dockerfile
├── pom.xml
└── src
│ └── main
│ ├── java
│ └── cn
│ │ └── edu
│ │ └── scut
│ │ ├── EdgeNodeApp.java
│ │ ├── bean
│ │ └── EdgeNodeSystem.java
│ │ ├── config
│ │ ├── Config.java
│ │ └── HadoopConfig.java
│ │ ├── controller
│ │ ├── CmdController.java
│ │ ├── EdgeNodeController.java
│ │ ├── FileController.java
│ │ ├── ModelController.java
│ │ └── UserController.java
│ │ ├── queue
│ │ ├── ExecutionQueue.java
│ │ └── TransmissionQueue.java
│ │ ├── scheduler
│ │ ├── DRLScheduler.java
│ │ ├── IScheduler.java
│ │ ├── RandomScheduler.java
│ │ ├── ReactiveScheduler.java
│ │ └── ReliabilityTwoChoice.java
│ │ ├── service
│ │ └── EdgeNodeSystemService.java
│ │ ├── thread
│ │ ├── ExecutionRunnable.java
│ │ └── TransmissionRunnable.java
│ │ └── utils
│ │ └── SpringBeanUtils.java
│ └── resources
│ ├── META-INF
│ └── MANIFEST.MF
│ └── bootstrap.yaml
├── k8s
├── edge
│ ├── edge-controller.yaml
│ ├── edge-experiment.yaml
│ └── edge-node.yaml
├── mysql.yaml
└── nacos.yaml
├── nacos
└── DEFAULT_GROUP
│ ├── application-edge-computing.yaml
│ ├── application-mappo.yaml
│ ├── application-masac.yaml
│ ├── application-random.yaml
│ ├── application-reactive.yaml
│ ├── application-reliability-two-choice.yaml
│ └── application.yaml
├── pom.xml
├── sql
└── create_table.sql
├── start-experiment.sh
└── stop-experiment.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | results
2 | .idea
3 | edge-algorithm/target
4 | edge-api/target
5 | edge-controller/target
6 | edge-experiment/target
7 | edge-node/target
8 | edge-test/target
--------------------------------------------------------------------------------
/edge-algorithm/README.md:
--------------------------------------------------------------------------------
1 | Deep reinforcement learning (DRL) algorithms: the multi-agent MAPPO and MASAC agents.
2 |
--------------------------------------------------------------------------------
/edge-algorithm/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | edge-computing
7 | cn.edu.scut
8 | 1.0-SNAPSHOT
9 |
10 | 4.0.0
11 |
12 | edge-algorithm
13 |
14 |
15 | 17
16 | 17
17 | UTF-8
18 |
19 |
20 |
21 |
22 | cn.edu.scut
23 | edge-api
24 | 1.0-SNAPSHOT
25 |
26 |
27 | org.springframework.boot
28 | spring-boot-starter
29 |
30 |
31 | org.projectlombok
32 | lombok
33 |
34 |
35 |
36 | ai.djl
37 | api
38 |
39 |
40 | ai.djl.pytorch
41 | pytorch-engine
42 | runtime
43 |
44 |
45 | ai.djl.pytorch
46 | pytorch-native-cu116
47 | linux-x86_64
48 | runtime
49 |
50 |
51 | ai.djl.pytorch
52 | pytorch-jni
53 | runtime
54 |
55 |
56 |
57 |
58 | tech.tablesaw
59 | tablesaw-core
60 | LATEST
61 |
62 |
63 | tech.tablesaw
64 | tablesaw-jsplot
65 | 0.43.1
66 |
67 |
68 |
69 |
70 |
71 | org.springframework.boot
72 | spring-boot-starter-test
73 | test
74 |
75 |
76 |
77 | org.springframework.boot
78 | spring-boot-starter-logging
79 |
80 |
81 |
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/agent/IMAAgent.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.agent;
2 |
3 | import java.io.InputStream;
4 |
/**
 * Common contract for the multi-agent reinforcement-learning agents in this module
 * (this module contains MAPPO and MASAC implementations).
 *
 * <p>The contract covers learning (action selection and training) and model persistence
 * through three channels: the local file system, HDFS, and a raw input stream.
 */
public interface IMAAgent {

    /**
     * Runs one training update of the agent's networks.
     */
    void train();

    /**
     * Chooses an action index for a single agent.
     *
     * @param state       the agent's observation vector
     * @param availAction action-availability mask; in MAPPOAgent, entries equal to 0 are masked out
     * @param training    whether the call happens during training (implementations may ignore it)
     * @return the selected action index
     */
    int selectAction(float[] state, int[] availAction, boolean training);

    /**
     * Loads model parameters from local storage identified by {@code flag}.
     *
     * @param flag model identifier; MAPPOAgent resolves it under {@code results/model/}
     */
    void loadModel(String flag);

    /**
     * Saves model parameters to local storage identified by {@code flag}.
     *
     * @param flag model identifier; MAPPOAgent resolves it under {@code results/model/}
     */
    void saveModel(String flag);

    /**
     * Downloads model parameters identified by {@code flag} from HDFS and loads them.
     *
     * @param flag model identifier
     */
    void loadHdfsModel(String flag);

    /**
     * Saves model parameters identified by {@code flag} and uploads them to HDFS.
     *
     * @param flag model identifier
     */
    void saveHdfsModel(String flag);

    /**
     * Loads one model's parameters from a stream.
     *
     * @param inputStream serialized parameters
     * @param fileName    selects which model the stream belongs to
     *                    (MAPPOAgent accepts "actor.param" and "critic.param")
     */
    void loadSteamModel(InputStream inputStream, String fileName);
}
21 |
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/agent/MABuffer.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.agent;
2 |
3 |
4 | import ai.djl.ndarray.NDList;
5 | import ai.djl.ndarray.NDManager;
6 | import ai.djl.ndarray.index.NDIndex;
7 | import ai.djl.ndarray.types.DataType;
8 | import ai.djl.ndarray.types.Shape;
9 | import cn.edu.scut.util.FileUtils;
10 | import lombok.Getter;
11 | import lombok.extern.slf4j.Slf4j;
12 | import org.apache.hadoop.fs.FileSystem;
13 | import org.apache.hadoop.fs.Path;
14 | import org.springframework.beans.factory.InitializingBean;
15 | import org.springframework.beans.factory.annotation.Autowired;
16 | import org.springframework.beans.factory.annotation.Value;
17 | import org.springframework.context.annotation.Lazy;
18 | import org.springframework.stereotype.Component;
19 |
20 | import java.io.IOException;
21 | import java.nio.file.Files;
22 | import java.nio.file.Paths;
23 | import java.util.ArrayList;
24 | import java.util.Collections;
25 | import java.util.Random;
26 |
27 | @Component
28 | @Lazy
29 | @Slf4j
30 | public class MABuffer implements InitializingBean {
31 |
32 | @Value("${rl.state-shape}")
33 | private int stateShape;
34 |
35 | @Value("${edgeComputing.edgeNodeNumber}")
36 | private int agentNumber;
37 |
38 | @Value("${rl.action-shape}")
39 | private int actionShape;
40 |
41 | @Value("${rl.buffer-size}")
42 | private int bufferSize;
43 |
44 | @Value("${rl.batch-size}")
45 | private int batchSize;
46 |
47 | @Value("${hadoop.hdfs.url}")
48 | private String hdfsUrl;
49 |
50 | @Getter
51 | private float[][][] states;
52 | @Getter
53 | private int[][][] actions;
54 | @Getter
55 | private float[][][] rewards;
56 | @Getter
57 | private int[][][] availActions;
58 | @Getter
59 | private float[][][] nextStates;
60 |
61 | private int index = 0;
62 |
63 | private int size = 0;
64 |
65 | @Autowired
66 | private FileSystem fileSystem;
67 |
68 | @Autowired
69 | private Random bufferRandom;
70 |
71 | @Override
72 | public void afterPropertiesSet() {
73 | log.info("buffer-size {} ", bufferSize);
74 | states = new float[bufferSize][agentNumber][stateShape];
75 | actions = new int[bufferSize][agentNumber][1];
76 | availActions = new int[bufferSize][agentNumber][actionShape];
77 | rewards = new float[bufferSize][agentNumber][1];
78 | nextStates = new float[bufferSize][agentNumber][stateShape];
79 | }
80 |
81 | public void insert(float[][] state, int[][] action, int[][] availAction, float[][] reward, float[][] nextState) {
82 | states[index] = state;
83 | actions[index] = action;
84 | availActions[index] = availAction;
85 | rewards[index] = reward;
86 | nextStates[index] = nextState;
87 | index = (index + 1) % bufferSize;
88 | size = Math.min(size + 1, bufferSize);
89 | }
90 |
91 | // off-policy
92 | public NDList sample(NDManager manager) {
93 | var list = new ArrayList();
94 | for (int i = 0; i < size; i++) {
95 | list.add(i);
96 | }
97 | Collections.shuffle(list, bufferRandom);
98 | var batch = new int[batchSize];
99 | for (int i = 0; i < batchSize; i++) {
100 | batch[i] = list.get(i);
101 | }
102 |
103 | var ndStates = manager.zeros(new Shape(batchSize, agentNumber, stateShape), DataType.FLOAT32);
104 | var ndActions = manager.zeros(new Shape(batchSize, agentNumber, 1), DataType.INT32);
105 | var ndAvailActions = manager.zeros(new Shape(batchSize, agentNumber, actionShape), DataType.INT32);
106 | var ndRewards = manager.zeros(new Shape(batchSize, agentNumber, 1), DataType.FLOAT32);
107 | var ndNextStates = manager.zeros(new Shape(batchSize, agentNumber, stateShape), DataType.FLOAT32);
108 |
109 | for (int i = 0; i < batchSize; i++) {
110 | var ndIndex = new NDIndex(i);
111 | ndStates.set(ndIndex, manager.create(states[batch[i]]));
112 | ndActions.set(ndIndex, manager.create(actions[batch[i]]));
113 | ndAvailActions.set(ndIndex, manager.create(availActions[batch[i]]));
114 | ndRewards.set(ndIndex, manager.create(rewards[batch[i]]));
115 | ndNextStates.set(ndIndex, manager.create(nextStates[batch[i]]));
116 | }
117 | return new NDList(ndStates, ndActions, ndAvailActions, ndRewards, ndNextStates);
118 | }
119 |
120 | // on-policy
121 | public NDList sampleAll(NDManager manager) {
122 | var ndStates = manager.zeros(new Shape(bufferSize, agentNumber, stateShape), DataType.FLOAT32);
123 | var ndActions = manager.zeros(new Shape(bufferSize, agentNumber, 1), DataType.INT32);
124 | var ndAvailActions = manager.zeros(new Shape(bufferSize, agentNumber, actionShape), DataType.INT32);
125 | var ndRewards = manager.zeros(new Shape(bufferSize, agentNumber, 1), DataType.FLOAT32);
126 | var ndNextStates = manager.zeros(new Shape(bufferSize, agentNumber, stateShape), DataType.FLOAT32);
127 | for (int i = 0; i < bufferSize; i++) {
128 | var ndIndex = new NDIndex(i);
129 | ndStates.set(ndIndex, manager.create(states[i]));
130 | ndActions.set(ndIndex, manager.create(actions[i]));
131 | ndAvailActions.set(ndIndex, manager.create(availActions[i]));
132 | ndRewards.set(ndIndex, manager.create(rewards[i]));
133 | ndNextStates.set(ndIndex, manager.create(nextStates[i]));
134 | }
135 | return new NDList(ndStates, ndActions, ndAvailActions, ndRewards, ndNextStates);
136 | }
137 |
138 | public void saveHdfs(String flag) {
139 | var path = Paths.get("results", "buffer", flag);
140 | try {
141 | Files.createDirectories(path);
142 | } catch (IOException e) {
143 | throw new RuntimeException(e);
144 | }
145 | String basePath = "results/buffer/" + flag + "/";
146 | try {
147 | FileUtils.writeObject(states, Paths.get("results", "buffer", flag, "states.array"));
148 | fileSystem.copyFromLocalFile(true, true, new Path(basePath + "states.array"), new Path(hdfsUrl + "/" + basePath + "states.array"));
149 |
150 | FileUtils.writeObject(actions, Paths.get("results", "buffer", flag, "actions.array"));
151 | fileSystem.copyFromLocalFile(true, true, new Path(basePath + "actions.array"), new Path(hdfsUrl + "/" + basePath + "actions.array"));
152 |
153 | FileUtils.writeObject(availActions, Paths.get("results", "buffer", flag, "availActions.array"));
154 | fileSystem.copyFromLocalFile(true, true, new Path(basePath + "availActions.array"), new Path(hdfsUrl + "/" + basePath + "availActions.array"));
155 |
156 | FileUtils.writeObject(rewards, Paths.get("results", "buffer", flag, "rewards.array"));
157 | fileSystem.copyFromLocalFile(true, true, new Path(basePath + "rewards.array"), new Path(hdfsUrl + "/" + basePath + "rewards.array"));
158 |
159 | FileUtils.writeObject(nextStates, Paths.get("results", "buffer", flag, "nextStates.array"));
160 | fileSystem.copyFromLocalFile(true, true, new Path(basePath + "nextStates.array"), new Path(hdfsUrl + "/" + basePath + "nextStates.array"));
161 | } catch (IOException e) {
162 | throw new RuntimeException(e);
163 | }
164 | }
165 |
166 | public void loadHdfs(String bufferPath) {
167 | var path = Paths.get("results", "buffer", bufferPath);
168 | try {
169 | Files.createDirectories(path);
170 | } catch (IOException e) {
171 | throw new RuntimeException(e);
172 | }
173 | String basePath = "results/buffer/" + bufferPath + "/";
174 | try {
175 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + basePath + "states.array"), new Path(basePath + "states.array"), false);
176 | states = (float[][][]) FileUtils.readObject(Paths.get("results", "buffer", bufferPath, "states.array"));
177 |
178 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + basePath + "actions.array"), new Path(basePath + "actions.array"), false);
179 | actions = (int[][][]) FileUtils.readObject(Paths.get("results", "buffer", bufferPath, "actions.array"));
180 |
181 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + basePath + "availActions.array"), new Path(basePath + "availActions.array"), false);
182 | availActions = (int[][][]) FileUtils.readObject(Paths.get("results", "buffer", bufferPath, "availActions.array"));
183 |
184 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + basePath + "rewards.array"), new Path(basePath + "rewards.array"), false);
185 | rewards = (float[][][]) FileUtils.readObject(Paths.get("results", "buffer", bufferPath, "rewards.array"));
186 |
187 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + basePath + "nextStates.array"), new Path(basePath + "nextStates.array"), false);
188 | nextStates = (float[][][]) FileUtils.readObject(Paths.get("results", "buffer", bufferPath, "nextStates.array"));
189 | } catch (IOException e) {
190 | throw new RuntimeException(e);
191 | }
192 | index = bufferSize;
193 | size = bufferSize;
194 | }
195 | }
196 |
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/agent/mappo/MAPPOAgent.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.agent.mappo;
2 |
3 | import ai.djl.MalformedModelException;
4 | import ai.djl.Model;
5 | import ai.djl.ndarray.NDArray;
6 | import ai.djl.ndarray.NDArrays;
7 | import ai.djl.ndarray.NDList;
8 | import ai.djl.ndarray.NDManager;
9 | import ai.djl.training.TrainingConfig;
10 | import ai.djl.translate.NoopTranslator;
11 | import ai.djl.translate.TranslateException;
12 | import cn.edu.scut.agent.IMAAgent;
13 | import cn.edu.scut.agent.MABuffer;
14 | import cn.edu.scut.util.DJLUtils;
15 | import cn.edu.scut.util.FileUtils;
16 | import lombok.extern.slf4j.Slf4j;
17 | import org.apache.hadoop.fs.FileSystem;
18 | import org.apache.hadoop.fs.Path;
19 | import org.springframework.beans.factory.InitializingBean;
20 | import org.springframework.beans.factory.annotation.Autowired;
21 | import org.springframework.beans.factory.annotation.Value;
22 | import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
23 | import org.springframework.stereotype.Component;
24 |
25 | import java.io.IOException;
26 | import java.io.InputStream;
27 | import java.nio.file.Paths;
28 | import java.util.Random;
29 | import java.util.concurrent.locks.Lock;
30 | import java.util.concurrent.locks.ReentrantLock;
31 |
32 | @Component
33 | @Slf4j
34 | @ConditionalOnProperty(name = "rl.name", havingValue = "mappo")
35 | public class MAPPOAgent implements IMAAgent, InitializingBean {
36 |
37 | @Autowired
38 | private Random schedulerRandom;
39 |
40 | @Autowired
41 | private NDManager manager;
42 |
43 | @Autowired
44 | private MABuffer buffer;
45 |
46 | @Value("${rl.use-normalized-reward}")
47 | private boolean useNormalizedReward;
48 |
49 | @Value("${rl.epoch}")
50 | private int epoch;
51 |
52 | @Value("${rl.clip}")
53 | private float clip;
54 |
55 | @Value("${rl.use-entropy}")
56 | private boolean useEntropy;
57 |
58 | @Value("${rl.entropy-coef}")
59 | private float entropyCoef;
60 |
61 | @Value("${rl.gamma}")
62 | private float gamma;
63 |
64 | @Value("${edgeComputing.episodeLimit}")
65 | private int episodeLimit;
66 |
67 | @Value("${rl.use-gae}")
68 | private boolean useGae;
69 |
70 | @Value("${rl.gae-lambda}")
71 | private float gaeLambda;
72 |
73 | @Value("${hadoop.hdfs.url}")
74 | private String hdfsUrl;
75 |
76 | @Value("${spring.application.name}")
77 | String name;
78 |
79 | @Autowired
80 | private Model actorModel;
81 |
82 | @Autowired
83 | private Model criticModel;
84 |
85 | @Autowired
86 | private FileSystem fileSystem;
87 |
88 | @Autowired
89 | private TrainingConfig trainingConfig;
90 |
91 | private final Lock lock = new ReentrantLock();
92 |
93 | @Override
94 | public void afterPropertiesSet() {
95 | }
96 |
97 | public int selectAction(float[] state, int[] availAction, boolean training) {
98 | int action;
99 | try {
100 | lock.lock();
101 | var subManager = manager.newSubManager();
102 | try (subManager) {
103 | var predictor = actorModel.newPredictor(new NoopTranslator());
104 | try (predictor) {
105 | NDArray out = null;
106 | try {
107 | out = predictor.predict(new NDList(subManager.create(state))).singletonOrThrow();
108 | log.info("{}", out);
109 | } catch (TranslateException e) {
110 | log.error("predict error: {}", e.getMessage());
111 | }
112 | var bool = subManager.create(availAction).eq(0);
113 | assert out != null;
114 | out.set(bool, -1e8f);
115 | var prob = out.softmax(-1);
116 | log.info("prob: {}", prob);
117 | action = DJLUtils.sampleMultinomial(schedulerRandom, prob);
118 | }
119 | }
120 | } finally {
121 | lock.unlock();
122 | }
123 | return action;
124 | }
125 |
126 | public void train() {
127 | try {
128 | lock.lock();
129 | var subManager = manager.newSubManager();
130 | try (subManager) {
131 | var list = buffer.sampleAll(subManager);
132 | var states = list.get(0);
133 | var actions = list.get(1);
134 | var availActions = list.get(2);
135 | var rewards = list.get(3);
136 | var nextStates = list.get(4);
137 | // Trick: normalized reward
138 | if (useNormalizedReward) {
139 | var mean = rewards.mean();
140 | var std = rewards.sub(mean).pow(2).mean().sqrt().add(1e-5f);
141 | rewards = rewards.sub(mean).div(std);
142 | }
143 | var criticTrainer = criticModel.newTrainer(trainingConfig);
144 | var actorTrainer = actorModel.newTrainer(trainingConfig);
145 | try (criticTrainer; actorTrainer) {
146 | var stateValues = criticTrainer.evaluate(new NDList(states)).singletonOrThrow();
147 | var nextStateValues = criticTrainer.evaluate(new NDList(nextStates)).singletonOrThrow();
148 | // gae
149 | NDArray advantages;
150 | if (useGae) {
151 | advantages = DJLUtils.getGae(rewards, stateValues, nextStateValues, subManager, gamma, episodeLimit, gaeLambda);
152 | } else {
153 | var returns = DJLUtils.getReturn(rewards, nextStateValues, gamma, episodeLimit);
154 | advantages = returns.sub(stateValues);
155 | }
156 | var targets = advantages.add(stateValues);
157 | var out = actorTrainer.evaluate(new NDList(states)).singletonOrThrow();
158 | var index1 = availActions.eq(0);
159 | out.set(index1, -1e5f);
160 | var logProb = out.logSoftmax(-1);
161 | var oldLogProbTaken = logProb.gather(actions, -1);
162 | for (int i = 0; i < epoch; i++) {
163 | var gradientCollector = actorTrainer.newGradientCollector();
164 | try (gradientCollector) {
165 | var output = actorTrainer.forward(new NDList(states)).singletonOrThrow();
166 | var index = availActions.eq(0);
167 | output.set(index, -1e5f);
168 | var probabilities = output.softmax(-1);
169 | // attention!!!
170 | var logProbabilities = output.logSoftmax(-1);
171 | var logProbTaken = logProbabilities.gather(actions, -1);
172 | // PPO
173 | var ratios = logProbTaken.sub(oldLogProbTaken).exp();
174 | //
175 | var surr1 = ratios.mul(advantages);
176 | var surr2 = ratios.clip(1 - clip, 1 + clip).mul(advantages);
177 | // maximizing
178 | var loss = NDArrays.minimum(surr1, surr2).mean().neg();
179 | // Trick: entropy
180 | if (useEntropy) {
181 | // entropy definition:
182 | var entropy = probabilities.mul(logProbabilities).neg();
183 | // maximizing entropy
184 | var entropyLoss = entropy.mean().mul(entropyCoef).neg();
185 | loss = loss.add(entropyLoss);
186 | }
187 | gradientCollector.backward(loss);
188 | actorTrainer.step();
189 | }
190 | var criticGradientCollector = criticTrainer.newGradientCollector();
191 | try (criticGradientCollector) {
192 | var newValues = criticTrainer.forward(new NDList(states)).singletonOrThrow();
193 | var criticLoss = newValues.sub(targets).pow(2).mean().mul(0.5f);
194 | criticGradientCollector.backward(criticLoss);
195 | criticTrainer.step();
196 | }
197 | }
198 | }
199 | }
200 | } finally {
201 | lock.unlock();
202 | }
203 | }
204 |
205 | public void saveModel(String flag) {
206 | String basePath = "results/model/" + flag;
207 | try {
208 | actorModel.save(Paths.get(basePath), null);
209 | criticModel.save(Paths.get(basePath), null);
210 | } catch (IOException e) {
211 | log.info("save model: {}", e.getMessage());
212 | }
213 | }
214 |
215 | public void loadModel(String flag) {
216 | String basePath = "results/model/" + flag;
217 | try {
218 | actorModel.load(Paths.get(basePath));
219 | actorModel.load(Paths.get(basePath));
220 | } catch (IOException | MalformedModelException e) {
221 | log.info("load model: {}", e.getMessage());
222 | }
223 | }
224 |
225 | public void saveHdfsModel(String flag) {
226 | String basePath = "results/model/" + flag;
227 | try {
228 | actorModel.save(Paths.get(basePath), null);
229 | criticModel.save(Paths.get(basePath), null);
230 | } catch (IOException e) {
231 | log.error("save model error: {}", e.getMessage());
232 | }
233 | var actorPath = basePath + "/actor-0000.params";
234 | var criticPath = basePath + "/critic-0000.params";
235 | try {
236 | if (fileSystem.exists(new Path(hdfsUrl + "/" + basePath))) {
237 | fileSystem.delete(new Path(hdfsUrl + "/" + basePath), true);
238 | }
239 | } catch (IOException e) {
240 | log.error("delete file in file system error: {}", e.getMessage());
241 | }
242 | try {
243 | fileSystem.copyFromLocalFile(true, true, new Path(actorPath), new Path(hdfsUrl + "/" + actorPath));
244 | fileSystem.copyFromLocalFile(true, true, new Path(criticPath), new Path(hdfsUrl + "/" + criticPath));
245 | } catch (IOException e) {
246 | log.error("file system save file error: {}", e.getMessage());
247 | }
248 | }
249 |
250 | public void loadHdfsModel(String flag) {
251 | String basePath = "results/model/" + flag;
252 | try {
253 | var actorPath = basePath + "/actor-0000.params";
254 | var criticPath = basePath + "/critic-0000.params";
255 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + actorPath), new Path(actorPath), true);
256 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + criticPath), new Path(criticPath), true);
257 | } catch (IOException e) {
258 | log.info("load hdfs model: {}", e.getMessage());
259 | }
260 | try {
261 | actorModel.load(Paths.get(basePath));
262 | actorModel.load(Paths.get(basePath));
263 | } catch (IOException | MalformedModelException e) {
264 | throw new RuntimeException(e);
265 | }
266 | var parameters = actorModel.getBlock().getParameters();
267 | for (String key : parameters.keys()) {
268 | parameters.get(key).getArray().setRequiresGradient(true);
269 | }
270 | var parameters2 = criticModel.getBlock().getParameters();
271 | for (String key : parameters2.keys()) {
272 | parameters.get(key).getArray().setRequiresGradient(true);
273 | }
274 | FileUtils.recursiveDelete(Paths.get(basePath).toFile());
275 | }
276 |
277 | public void loadSteamModel(InputStream inputStream, String fileName) {
278 | switch (fileName) {
279 | case "actor.param" -> {
280 | try {
281 | actorModel.load(inputStream);
282 | } catch (IOException | MalformedModelException e) {
283 | log.info("load stream model error: {}", e.getMessage());
284 | }
285 | }
286 | case "critic.param" -> {
287 | try {
288 | criticModel.load(inputStream);
289 | } catch (IOException | MalformedModelException e) {
290 | log.info("load stream model error: {}", e.getMessage());
291 | }
292 | }
293 | default -> throw new RuntimeException("fileName error!");
294 | }
295 | }
296 | }
297 |
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/agent/mappo/MAPPOConfig.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.agent.mappo;
2 |
3 | import ai.djl.Model;
4 | import ai.djl.ndarray.NDArray;
5 | import ai.djl.ndarray.NDList;
6 | import ai.djl.ndarray.NDManager;
7 | import ai.djl.ndarray.types.DataType;
8 | import ai.djl.ndarray.types.Shape;
9 | import ai.djl.nn.Activation;
10 | import ai.djl.nn.Block;
11 | import ai.djl.nn.SequentialBlock;
12 | import ai.djl.nn.core.Linear;
13 | import ai.djl.training.DefaultTrainingConfig;
14 | import ai.djl.training.TrainingConfig;
15 | import ai.djl.training.loss.Loss;
16 | import ai.djl.training.optimizer.Optimizer;
17 | import ai.djl.training.tracker.PolynomialDecayTracker;
18 | import ai.djl.training.tracker.Tracker;
19 | import org.springframework.beans.factory.annotation.Value;
20 | import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
21 | import org.springframework.context.annotation.Bean;
22 | import org.springframework.context.annotation.Configuration;
23 |
24 | @Configuration
25 | @ConditionalOnProperty(name = "rl.name", havingValue = "mappo")
26 | public class MAPPOConfig {
27 | @Value("${rl.epoch}")
28 | private int epoch;
29 |
30 | @Value("${rl.start-learning-rate}")
31 | private float startLearningRate;
32 |
33 | @Value("${rl.end-learning-rate}")
34 | private float endLearningRate;
35 |
36 | @Value("${rl.use-learning-rate-decay}")
37 | private boolean useLearningRateDecay;
38 |
39 | @Value("${rl.learning-rate}")
40 | private float learningRate;
41 |
42 | @Value("${rl.use-clip-grad}")
43 | private boolean useClipGrad;
44 |
45 | @Value("${rl.clip-grad-coef}")
46 | private float clipGradCoef;
47 |
48 | @Value("${rl.hidden-shape}")
49 | private int hiddenShape;
50 |
51 | @Value("${rl.action-shape}")
52 | private int actionShape;
53 |
54 | @Value("${rl.state-shape}")
55 | private int stateShape;
56 |
57 | @Value("${edgeComputing.episodeNumber}")
58 | private int episodeNumber;
59 |
60 | @Value("${rl.batch-size}")
61 | private int batchSize;
62 |
63 | @Value("${edgeComputing.edgeNodeNumber}")
64 | private int agentNumber;
65 |
66 | @Bean
67 | public Tracker polynomialDecayTracker() {
68 | int trainNum = episodeNumber * epoch * 2;
69 | return PolynomialDecayTracker.builder()
70 | .setBaseValue(startLearningRate)
71 | .setEndLearningRate(endLearningRate)
72 | .setDecaySteps(trainNum) // share by actor and critic
73 | .build();
74 | }
75 |
76 | @Bean
77 | public Optimizer optimizer(Tracker polynomialDecayTracker) {
78 | var adam = Optimizer.adam();
79 | if (useLearningRateDecay) {
80 | adam.optLearningRateTracker(polynomialDecayTracker);
81 | } else {
82 | adam.optLearningRateTracker(Tracker.fixed(learningRate));
83 | }
84 | if (useClipGrad) {
85 | adam.optClipGrad(clipGradCoef);
86 | }
87 | return adam.build();
88 | }
89 |
90 | public Block createNetwork(NDManager manager, int outputDim) {
91 | var block = new SequentialBlock();
92 | block.add(Linear.builder().setUnits(hiddenShape).build());
93 | block.add(Activation::relu);
94 | block.add(Linear.builder().setUnits(hiddenShape).build());
95 | block.add(Activation::relu);
96 | block.add(Linear.builder().setUnits(outputDim).build());
97 | block.initialize(manager, DataType.FLOAT32, new Shape(batchSize, agentNumber, stateShape));
98 | return block;
99 | }
100 |
101 | @Bean
102 | public Model actorModel(NDManager manager) {
103 | var model = Model.newInstance("actor");
104 | model.setBlock(createNetwork(manager, actionShape));
105 | return model;
106 | }
107 |
108 | @Bean
109 | public Model criticModel(NDManager manager) {
110 | var model = Model.newInstance("critic");
111 | model.setBlock(createNetwork(manager, 1));
112 | return model;
113 | }
114 |
115 | @Bean
116 | public Loss loss() {
117 | return new Loss("null") {
118 | @Override
119 | public NDArray evaluate(NDList ndList, NDList ndList1) {
120 | return null;
121 | }
122 | };
123 | }
124 |
125 | @Bean
126 | public TrainingConfig trainingConfig(Optimizer optimizer, Loss loss) {
127 | return new DefaultTrainingConfig(loss)
128 | .optOptimizer(optimizer);
129 | }
130 | }
131 |
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/agent/masac/AlphaBlock.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.agent.masac;
2 |
3 | import ai.djl.ndarray.NDList;
4 | import ai.djl.ndarray.types.Shape;
5 | import ai.djl.nn.AbstractBlock;
6 | import ai.djl.nn.Parameter;
7 | import ai.djl.training.ParameterStore;
8 | import ai.djl.training.initializer.ConstantInitializer;
9 | import ai.djl.util.PairList;
10 |
11 | public class AlphaBlock extends AbstractBlock {
12 |
13 | private Parameter logAlpha;
14 |
15 | public AlphaBlock(float initAlpha) {
16 | logAlpha = addParameter(Parameter.builder()
17 | .setName("alpha")
18 | .setType(Parameter.Type.WEIGHT)
19 | .optShape(new Shape(1))
20 | .optInitializer(new ConstantInitializer((float) Math.log(initAlpha)))
21 | .optRequiresGrad(true)
22 | .build());
23 | }
24 |
25 | @Override
26 | protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList params) {
27 | return new NDList(logAlpha.getArray());
28 | }
29 |
30 | @Override
31 | public Shape[] getOutputShapes(Shape[] inputShapes) {
32 | return new Shape[]{new Shape(1)};
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/agent/masac/MASACAgent.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.agent.masac;
2 |
3 | import ai.djl.MalformedModelException;
4 | import ai.djl.Model;
5 | import ai.djl.ndarray.NDArray;
6 | import ai.djl.ndarray.NDArrays;
7 | import ai.djl.ndarray.NDList;
8 | import ai.djl.ndarray.NDManager;
9 | import ai.djl.training.TrainingConfig;
10 | import ai.djl.translate.NoopTranslator;
11 | import ai.djl.translate.TranslateException;
12 | import cn.edu.scut.agent.IMAAgent;
13 | import cn.edu.scut.agent.MABuffer;
14 | import cn.edu.scut.util.DJLUtils;
15 | import cn.edu.scut.util.FileUtils;
16 | import lombok.extern.slf4j.Slf4j;
17 | import org.apache.hadoop.fs.FileSystem;
18 | import org.apache.hadoop.fs.Path;
19 | import org.springframework.beans.factory.InitializingBean;
20 | import org.springframework.beans.factory.annotation.Autowired;
21 | import org.springframework.beans.factory.annotation.Value;
22 | import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
23 | import org.springframework.stereotype.Component;
24 |
25 | import java.io.IOException;
26 | import java.io.InputStream;
27 | import java.nio.file.Files;
28 | import java.nio.file.Paths;
29 | import java.util.Random;
30 |
31 | @Component
32 | @Slf4j
33 | @ConditionalOnProperty(name = "rl.name", havingValue = "masac")
34 | public class MASACAgent implements InitializingBean, IMAAgent {
35 |
36 | @Value("${rl.action-shape}")
37 | private int actionShape;
38 |
39 | @Value("${rl.alpha}")
40 | private float alpha;
41 |
42 | @Value("${rl.gamma}")
43 | private float gamma;
44 |
45 | @Value("${rl.use-soft-update}")
46 | private boolean useSoftUpdate;
47 |
48 | @Value("${rl.tau}")
49 | private float tau;
50 |
51 | @Value("${rl.use-adaptive-alpha}")
52 | private boolean useAdaptiveAlpha;
53 |
54 | @Value("${rl.use-normalized-reward}")
55 | private boolean useNormalizedReward;
56 |
57 | @Autowired
58 | private Random schedulerRandom;
59 |
60 | @Autowired
61 | private Model q1Model;
62 | @Autowired
63 | private Model q2Model;
64 |
65 | @Autowired
66 | private Model targetQ1Model;
67 | @Autowired
68 | private Model targetQ2Model;
69 |
70 | @Autowired
71 | private Model actorModel;
72 |
73 | @Autowired
74 | private Model criticModel;
75 |
76 | @Autowired
77 | private MABuffer buffer;
78 |
79 | @Autowired
80 | private NDManager manager;
81 |
82 | private float targetEntropy;
83 |
84 | @Autowired
85 | private Model alphaModel;
86 |
87 | @Value("${spring.application.name}")
88 | String name;
89 |
90 | @Autowired
91 | private FileSystem fileSystem;
92 |
93 | @Value("${hadoop.hdfs.url}")
94 | private String hdfsUrl;
95 |
96 | @Value("${rl.use-cql}")
97 | private boolean useCql;
98 |
99 | @Value("${rl.cql-weight}")
100 | private float cqlWeight;
101 |
102 | @Value("${rl.use-addition-critic}")
103 | private boolean useAdditionCritic;
104 |
105 | @Value("${rl.target-entropy-coef}")
106 | private float targetEntropyCoef;
107 |
108 | @Autowired
109 | private TrainingConfig trainingConfig;
110 |
111 | @Override
112 | public void afterPropertiesSet() throws Exception {
113 | // TARGET_ENTROPY = (float) Math.log(1.0 / ACTION_SHAPE) * 0.98f;
114 | // We use 0.6 because the recommended 0.98 will cause alpha explosion.
115 | targetEntropy = -(float) Math.log(1.0 / actionShape) * targetEntropyCoef;
116 | }
117 |
118 | public int selectAction(float[] state, int[] availAction, boolean training) {
119 | int action = 0;
120 | var subManager = manager.newSubManager();
121 | try (subManager) {
122 | var predictor = actorModel.newPredictor(new NoopTranslator());
123 | try (predictor) {
124 | try {
125 | var out = predictor.predict(new NDList(subManager.create(state))).singletonOrThrow();
126 | var bool = subManager.create(availAction).eq(0);
127 | out.set(bool, -1e8f);
128 | var prob = out.softmax(-1);
129 | action = DJLUtils.sampleMultinomial(schedulerRandom, prob);
130 | } catch (TranslateException e) {
131 | log.error("predict error: {}", e.getMessage());
132 | }
133 | }
134 | }
135 | return action;
136 | }
137 |
138 | public void train() {
139 | var subManager = manager.newSubManager();
140 | try (subManager) {
141 | NDList list = buffer.sample(subManager);
142 | var states = list.get(0);
143 | var actions = list.get(1);
144 | var availActions = list.get(2);
145 | var rewards = list.get(3);
146 | var nextStates = list.get(4);
147 |
148 | if (useNormalizedReward) {
149 | var mean = rewards.mean();
150 | var std = rewards.sub(mean).pow(2).mean().sqrt().add(1e-5f);
151 | rewards = rewards.sub(mean).div(std);
152 | }
153 |
154 | var q1Trainer = q1Model.newTrainer(trainingConfig);
155 | var q2Trainer = q2Model.newTrainer(trainingConfig);
156 | var actorTrainer = actorModel.newTrainer(trainingConfig);
157 | var criticTrainer = criticModel.newTrainer(trainingConfig);
158 | var targetQ1Predictor = targetQ1Model.newPredictor(new NoopTranslator());
159 | var targetQ2Predictor = targetQ2Model.newPredictor(new NoopTranslator());
160 | var alphaTrainer = alphaModel.newTrainer(trainingConfig);
161 | NDArray alphaValue;
162 | if (useAdaptiveAlpha) {
163 | alphaValue = alphaModel.getBlock().getParameters().get("alpha").getArray().duplicate().exp();
164 | } else {
165 | alphaValue = subManager.create(alpha);
166 | }
167 | try (q1Trainer; q2Trainer; actorTrainer; criticTrainer; targetQ1Predictor; targetQ2Predictor; alphaTrainer) {
168 | var actorOut = actorTrainer.evaluate(new NDList(nextStates)).singletonOrThrow();
169 | var nextLogProbabilities = actorOut.logSoftmax(-1);
170 | NDArray nextTargetQ1;
171 | NDArray nextTargetQ2;
172 | try {
173 | nextTargetQ1 = targetQ1Predictor.predict(new NDList(nextStates)).singletonOrThrow();
174 | nextTargetQ2 = targetQ2Predictor.predict(new NDList(nextStates)).singletonOrThrow();
175 | } catch (TranslateException e) {
176 | throw new RuntimeException(e);
177 | }
178 | var targetQ = nextLogProbabilities.exp().mul(NDArrays.minimum(nextTargetQ1, nextTargetQ2).sub(nextLogProbabilities.mul(alphaValue)));
179 | var avgTargetQ = targetQ.sum(new int[]{-1}, true);
180 | var target = rewards.add(avgTargetQ.mul(gamma));
181 | var q1Value = q1Trainer.evaluate(new NDList(states)).singletonOrThrow();
182 | var q2Value = q2Trainer.evaluate(new NDList(states)).singletonOrThrow();
183 | var lobProbabilityValue = actorTrainer.evaluate(new NDList(states)).singletonOrThrow().logSoftmax(-1);
184 |
185 | if (useAdaptiveAlpha) {
186 | var alphaGradientCollector = alphaTrainer.newGradientCollector();
187 | try (alphaGradientCollector) {
188 | var entropy = lobProbabilityValue.exp().mul(lobProbabilityValue).sum(new int[]{-1}).mean().neg();
189 | var logAlpha = alphaModel.getBlock().getParameters().get("alpha").getArray();
190 | var loss = logAlpha.mul(entropy.sub(targetEntropy));
191 | alphaGradientCollector.backward(loss);
192 | alphaTrainer.step();
193 | }
194 | }
195 |
196 | var actorGradientCollector = actorTrainer.newGradientCollector();
197 | try (actorGradientCollector) {
198 | var qMin = NDArrays.minimum(q1Value, q2Value);
199 | var lobProbabilities = actorTrainer.forward(new NDList(states)).singletonOrThrow().logSoftmax(-1);
200 | var loss = lobProbabilities.exp().mul(qMin.sub(lobProbabilities.mul(alphaValue))).sum(new int[]{-1}).mean().neg();
201 | actorGradientCollector.backward(loss);
202 | actorTrainer.step();
203 | }
204 |
205 | var q1GradientCollector = q1Trainer.newGradientCollector();
206 | try (q1GradientCollector) {
207 | var q1 = q1Trainer.forward(new NDList(states)).singletonOrThrow();
208 | var q1Action = q1.gather(actions, -1);
209 | var loss1 = q1Action.sub(target).pow(2).mean();
210 | if (useCql) {
211 | var cqlLoss1 = (q1.exp().sum().log().mean()).sub(q1Action.mean());
212 | loss1.add(cqlLoss1.mul(cqlWeight));
213 | }
214 | q1GradientCollector.backward(loss1);
215 | q1Trainer.step();
216 | }
217 | var q2GradientCollector = q2Trainer.newGradientCollector();
218 | try (q2GradientCollector) {
219 | var q2 = q2Trainer.forward(new NDList(states)).singletonOrThrow();
220 | var q2Action = q2.gather(actions, -1);
221 | var loss2 = q2Action.sub(target).pow(2).mean();
222 | if (useCql) {
223 | var cqlLoss2 = (q2.exp().sum().log().mean()).sub(q2Action.mean());
224 | loss2.add(cqlLoss2.mul(cqlWeight));
225 | }
226 | q2GradientCollector.backward(loss2);
227 | q2Trainer.step();
228 | }
229 | if (useAdditionCritic) {
230 | var criticGradientCollector = criticTrainer.newGradientCollector();
231 | try (criticGradientCollector) {
232 | var value = criticTrainer.evaluate(new NDList(states)).singletonOrThrow();
233 | var loss = value.sub(target).pow(2).mean();
234 | criticGradientCollector.backward(loss);
235 | criticTrainer.step();
236 | }
237 | }
238 | }
239 | if (useSoftUpdate) {
240 | DJLUtils.softUpdate(q1Model.getBlock(), targetQ1Model.getBlock(), tau);
241 | DJLUtils.softUpdate(q1Model.getBlock(), targetQ2Model.getBlock(), tau);
242 | }
243 | }
244 | }
245 |
246 | public void saveModel(String path) {
247 | String basePath = "results/model/" + path + "/";
248 | var actorPath = basePath + "actor.param";
249 | var criticPath = basePath + "critic.param";
250 | try {
251 | Files.createDirectories(Paths.get(actorPath).getParent());
252 | } catch (IOException e) {
253 | throw new RuntimeException(e);
254 | }
255 | DJLUtils.saveModel(Paths.get(actorPath), actorModel.getBlock());
256 | DJLUtils.saveModel(Paths.get(criticPath), criticModel.getBlock());
257 | }
258 |
259 | public void loadModel(String path) {
260 | String basePath = "results/model/" + path + "/";
261 | var actorPath = basePath + "actor.param";
262 | var criticPath = basePath + "critic.param";
263 | DJLUtils.loadModel(Paths.get(actorPath), actorModel.getBlock(), manager);
264 | DJLUtils.loadModel(Paths.get(criticPath), criticModel.getBlock(), manager);
265 | }
266 |
267 | // edge-experiment
268 | @Override
269 | public void saveHdfsModel(String flag) {
270 | String basePath = "results/model/" + flag;
271 | try {
272 | actorModel.save(Paths.get(basePath), null);
273 | criticModel.save(Paths.get(basePath), null);
274 | } catch (IOException e) {
275 | log.error("save model error: {}", e.getMessage());
276 | }
277 | var actorPath = basePath + "/actor-0000.params";
278 | var criticPath = basePath + "/critic-0000.params";
279 | try {
280 | if (fileSystem.exists(new Path(hdfsUrl + "/" + basePath))) {
281 | fileSystem.delete(new Path(hdfsUrl + "/" + basePath), true);
282 | }
283 | } catch (IOException e) {
284 | log.error("delete file in file system error: {}", e.getMessage());
285 | }
286 | try {
287 | fileSystem.copyFromLocalFile(true, true, new Path(actorPath), new Path(hdfsUrl + "/" + actorPath));
288 | fileSystem.copyFromLocalFile(true, true, new Path(criticPath), new Path(hdfsUrl + "/" + criticPath));
289 | } catch (IOException e) {
290 | log.error("file system save file error: {}", e.getMessage());
291 | }
292 | }
293 |
294 | // edge-node, edge-experiment
295 | @Override
296 | public void loadHdfsModel(String flag) {
297 | String basePath = "results/model/" + flag;
298 | try {
299 | var actorPath = basePath + "/actor-0000.params";
300 | var criticPath = basePath + "/critic-0000.params";
301 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + actorPath), new Path(actorPath), true);
302 | fileSystem.copyToLocalFile(false, new Path(hdfsUrl + "/" + criticPath), new Path(criticPath), true);
303 | } catch (IOException e) {
304 | log.info("load hdfs model: {}", e.getMessage());
305 | }
306 | try {
307 | actorModel.load(Paths.get(basePath));
308 | actorModel.load(Paths.get(basePath));
309 | } catch (IOException | MalformedModelException e) {
310 | throw new RuntimeException(e);
311 | }
312 | FileUtils.recursiveDelete(Paths.get(basePath).toFile());
313 | }
314 |
315 | @Override
316 | public void loadSteamModel(InputStream inputStream, String fileName) {
317 | switch (fileName) {
318 | case "actor.param" -> {
319 | try {
320 | actorModel.load(inputStream);
321 | } catch (IOException | MalformedModelException e) {
322 | log.info("load stream model error: {}", e.getMessage());
323 | }
324 | }
325 | case "critic.param" -> {
326 | try {
327 | criticModel.load(inputStream);
328 | } catch (IOException | MalformedModelException e) {
329 | log.info("load stream model error: {}", e.getMessage());
330 | }
331 | }
332 | default -> throw new RuntimeException("fileName error!");
333 | }
334 | }
335 | }
336 |
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/agent/masac/MASACConfiguration.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.agent.masac;
2 |
3 | import ai.djl.Model;
4 | import ai.djl.ndarray.NDArray;
5 | import ai.djl.ndarray.NDList;
6 | import ai.djl.ndarray.NDManager;
7 | import ai.djl.ndarray.types.DataType;
8 | import ai.djl.ndarray.types.Shape;
9 | import ai.djl.nn.Activation;
10 | import ai.djl.nn.Block;
11 | import ai.djl.nn.SequentialBlock;
12 | import ai.djl.nn.core.Linear;
13 | import ai.djl.training.DefaultTrainingConfig;
14 | import ai.djl.training.TrainingConfig;
15 | import ai.djl.training.loss.Loss;
16 | import ai.djl.training.optimizer.Optimizer;
17 | import ai.djl.training.tracker.Tracker;
18 | import lombok.extern.slf4j.Slf4j;
19 | import org.springframework.beans.factory.annotation.Value;
20 | import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
21 | import org.springframework.context.annotation.Bean;
22 | import org.springframework.context.annotation.Configuration;
23 |
24 | @Configuration
25 | @Slf4j
26 | @ConditionalOnProperty(name = "rl.name", havingValue = "masac")
27 | public class MASACConfiguration {
28 | @Value("${rl.state-shape}")
29 | private int stateShape;
30 |
31 | @Value("${rl.hidden-shape}")
32 | private int hiddenShape;
33 |
34 | @Value("${rl.action-shape}")
35 | private int actionShape;
36 |
37 | @Value("${rl.learning-rate}")
38 | private float learningRate;
39 |
40 | @Value("${rl.alpha}")
41 | private float alpha;
42 |
43 | public Block createNetwork(NDManager manager, int outputDim) {
44 | var block = new SequentialBlock();
45 | block.add(Linear.builder().setUnits(hiddenShape).build());
46 | block.add(Activation::relu);
47 | block.add(Linear.builder().setUnits(hiddenShape).build());
48 | block.add(Activation::relu);
49 | block.add(Linear.builder().setUnits(outputDim).build());
50 | block.initialize(manager, DataType.FLOAT32, new Shape(stateShape));
51 | return block;
52 | }
53 |
54 | @Bean
55 | public Optimizer optimizer() {
56 | return Optimizer.adam().optLearningRateTracker(Tracker.fixed(learningRate)).build();
57 | }
58 |
59 | @Bean
60 | public Model q1Model(NDManager manager) {
61 | var model = Model.newInstance("q1");
62 | model.setBlock(createNetwork(manager, actionShape));
63 | return model;
64 | }
65 |
66 | @Bean
67 | public Model q2Model(NDManager manager) {
68 | var model = Model.newInstance("q2");
69 | model.setBlock(createNetwork(manager, actionShape));
70 | return model;
71 | }
72 |
73 | @Bean
74 | public Model targetQ1Model(NDManager manager) {
75 | var model = Model.newInstance("targetQ1");
76 | model.setBlock(createNetwork(manager, actionShape));
77 | return model;
78 | }
79 |
80 | @Bean
81 | public Model targetQ2Model(NDManager manager) {
82 | var model = Model.newInstance("targetQ2");
83 | model.setBlock(createNetwork(manager, actionShape));
84 | return model;
85 | }
86 |
87 | @Bean
88 | public Model criticModel(NDManager manager) {
89 | var model = Model.newInstance("critic");
90 | model.setBlock(createNetwork(manager, 1));
91 | return model;
92 | }
93 |
94 | @Bean
95 | public Model actorModel(NDManager manager) {
96 | var model = Model.newInstance("actor");
97 | model.setBlock(createNetwork(manager, actionShape));
98 | return model;
99 | }
100 |
101 | @Bean
102 | public Model alphaModel(NDManager manager) {
103 | var model = Model.newInstance("alpha");
104 | var block = new AlphaBlock(alpha);
105 | block.initialize(manager, DataType.FLOAT32, new Shape(1));
106 | model.setBlock(block);
107 | return model;
108 | }
109 |
110 | @Bean
111 | public Loss loss() {
112 | return new Loss("null") {
113 | @Override
114 | public NDArray evaluate(NDList ndList, NDList ndList1) {
115 | return null;
116 | }
117 | };
118 | }
119 |
120 | @Bean
121 | public TrainingConfig trainingConfig(Optimizer optimizer, Loss loss) {
122 | return new DefaultTrainingConfig(loss)
123 | .optOptimizer(optimizer);
124 | }
125 | }
126 |
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/config/DJLConfig.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.config;
2 |
3 | import ai.djl.engine.Engine;
4 | import ai.djl.ndarray.NDManager;
5 | import org.springframework.beans.factory.annotation.Value;
6 | import org.springframework.context.annotation.Bean;
7 | import org.springframework.context.annotation.Configuration;
8 |
9 | import java.util.Random;
10 |
11 | @Configuration
12 | public class DJLConfig {
13 |
14 | @Value("${edgeComputing.seed}")
15 | Integer seed;
16 |
17 | @Value("${spring.application.name}")
18 | String name;
19 |
20 | @Bean
21 | public Engine engine() {
22 | Engine engine = Engine.getInstance();
23 | engine.setRandomSeed(seed);
24 | return engine;
25 | }
26 |
27 | @Bean
28 | public NDManager manager(Engine engine) {
29 | return engine.newBaseManager();
30 | }
31 |
32 | @Bean
33 | public Random bufferRandom(){
34 | var random = new Random();
35 | random.setSeed(seed);
36 | return random;
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/service/TransitionService.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.service;
2 |
3 | import cn.edu.scut.bean.*;
4 | import cn.edu.scut.util.ArrayUtils;
5 | import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
6 | import lombok.extern.slf4j.Slf4j;
7 | import org.springframework.beans.factory.annotation.Autowired;
8 | import org.springframework.beans.factory.annotation.Value;
9 | import org.springframework.stereotype.Service;
10 |
11 | import java.util.ArrayList;
12 | import java.util.Arrays;
13 |
14 | @Service
15 | @Slf4j
16 | public class TransitionService {
17 |
18 | @Value("${edgeComputing.maxTaskRate}")
19 | private double maxTaskRate;
20 |
21 | @Value("${edgeComputing.maxCpuCore}")
22 | private int maxCpuCore;
23 |
24 | @Value("${edgeComputing.maxTaskSize}")
25 | private long maxTaskSize;
26 |
27 | @Value("${edgeComputing.maxTaskComplexity}")
28 | private long maxTaskComplexity;
29 |
30 | @Value("${edgeComputing.maxTransmissionRate}")
31 | private double maxTransmissionRate;
32 |
33 | @Value("${edgeComputing.maxTransmissionFailureRate}")
34 | private double maxTransmissionFailureRate;
35 |
36 | @Value("${edgeComputing.maxExecutionFailureRate}")
37 | private double maxExecutionFailureRate;
38 |
39 | @Value("${edgeComputing.deadline}")
40 | private int deadline;
41 |
42 | @Value("${edgeComputing.edgeNodeNumber}")
43 | private int agentNumber;
44 |
45 | @Autowired
46 | private EdgeNodeService edgeNodeService;
47 |
48 | @Autowired
49 | private TaskService taskService;
50 |
51 | @Autowired
52 | private LinkService linkService;
53 |
54 | public float[] getState(Long taskId) {
55 | var task = taskService.getById(taskId);
56 | var edgeNodeInfos = edgeNodeService.list();
57 | var linkInfos = linkService.list(new QueryWrapper().eq("source", task.getSource()));
58 |
59 | var obsList = new ArrayList();
60 | // one-hot vector N
61 | int id = Integer.parseInt(task.getSource().split("-")[2]) - 1;
62 | for (int j = 0; j < agentNumber; j++) {
63 | if (j == id) {
64 | obsList.add(1.0f);
65 | } else {
66 | obsList.add(0.0f);
67 | }
68 | }
69 | // task dynamic 2
70 | obsList.add(Float.valueOf(task.getTaskSize()) / (float) (maxTaskSize * StoreConstants.Byte.value * StoreConstants.Kilo.value));
71 | obsList.add(task.getTaskComplexity() / (float) maxTaskComplexity);
72 | // link static N
73 | for (Link link : linkInfos) {
74 | obsList.add((float) (link.getTransmissionRate() / (maxTransmissionRate * Constants.Mega.value * Constants.Byte.value)));
75 | }
76 | // edge node static 3N
77 | // static information
78 | for (EdgeNode edgeNode : edgeNodeInfos) {
79 | obsList.add((float) (edgeNode.getCpuNum() / maxCpuCore));
80 | obsList.add((float) (edgeNode.getExecutionFailureRate() / maxExecutionFailureRate));
81 | obsList.add((float) (edgeNode.getTaskRate() / maxTaskRate));
82 | }
83 | return ArrayUtils.toFloatArray(obsList);
84 | }
85 |
86 |
87 | public int getAction(Long taskId) {
88 | var task = taskService.getById(taskId);
89 | int action;
90 | if (task.getStatus().equals(TaskStatus.EMPTY)) {
91 | action = agentNumber;
92 | } else {
93 | action = Integer.parseInt(task.getDestination().split("-")[2]) - 1;
94 | }
95 | return action;
96 | }
97 |
98 | public int[] getAvailAction(Long taskId) {
99 | var task = taskService.getById(taskId);
100 | return Arrays.stream(task.getAvailAction().split(",")).mapToInt(Integer::parseInt).toArray();
101 | }
102 |
103 | public float getReward(Long taskId) {
104 | var task = taskService.getById(taskId);
105 | float reward;
106 | if (task.getStatus().equals(TaskStatus.SUCCESS)) {
107 | reward = 1.0f;
108 | } else if (task.getStatus().equals(TaskStatus.EXECUTION_FAILURE) || task.getStatus().equals(TaskStatus.TRANSMISSION_FAILURE) || task.getStatus().equals(TaskStatus.DROP)) {
109 | reward = -1.0f;
110 | } else if (task.getStatus().equals(TaskStatus.EMPTY)) {
111 | return 0.0f;
112 | } else {
113 | throw new RuntimeException("task status error");
114 | }
115 | return reward;
116 | }
117 | }
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/util/DJLUtils.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.util;
2 |
3 | import ai.djl.MalformedModelException;
4 | import ai.djl.ndarray.NDArray;
5 | import ai.djl.ndarray.NDManager;
6 | import ai.djl.ndarray.index.NDIndex;
7 | import ai.djl.nn.Block;
8 | import ai.djl.nn.ParameterList;
9 | import lombok.extern.slf4j.Slf4j;
10 |
11 | import java.io.*;
12 | import java.nio.file.Files;
13 | import java.nio.file.Path;
14 | import java.util.List;
15 | import java.util.Random;
16 |
17 | @Slf4j
18 | public class DJLUtils {
19 | public static int sampleMultinomial(Random random, NDArray prob) {
20 | int value = 0;
21 | long size = prob.size();
22 | float rnd = random.nextFloat();
23 | for (int i = 0; i < size; i++) {
24 | float cut = prob.getFloat(value);
25 | if (rnd > cut) {
26 | value++;
27 | } else {
28 | return value;
29 | }
30 | rnd -= cut;
31 | }
32 | throw new IllegalArgumentException("Invalid multinomial distribution");
33 | }
34 |
35 | public static void softUpdate(Block network, Block targetNetwork, float tau) {
36 | var parameters = network.getParameters();
37 | var targetParameters = targetNetwork.getParameters();
38 | for (String key : parameters.keys()) {
39 | var array = parameters.get(key).getArray();
40 | var targetArray = targetParameters.get(key).getArray();
41 | var mixArray = targetArray.mul(1.0f - tau).add(array.mul(tau));
42 | targetParameters.get(key).getArray().set(mixArray.toFloatArray());
43 | }
44 | }
45 |
46 | public static void hardUpdate(Block network, Block targetNetwork) {
47 | var parameters = network.getParameters();
48 | var targetParameters = targetNetwork.getParameters();
49 | for (String key : parameters.keys()) {
50 | targetParameters.get(key).getArray().set(parameters.get(key).getArray().toFloatArray());
51 | }
52 | }
53 |
54 |
55 | public static void saveModel(Path path, Block actorBlock) {
56 | try {
57 | Files.createDirectories(path.getParent());
58 | } catch (IOException e) {
59 | throw new RuntimeException(e);
60 | }
61 | try (var fileOutputStream = new FileOutputStream(path.toFile());
62 | var dataOutputStream = new DataOutputStream(fileOutputStream)) {
63 | actorBlock.saveParameters(dataOutputStream);
64 | } catch (IOException e) {
65 | log.error("save model error");
66 | throw new RuntimeException(e);
67 | }
68 | }
69 |
70 | public static void loadModel(Path path, Block actorBlock, NDManager manager) {
71 | try (var fileInputStream = new FileInputStream(path.toFile());
72 | var dataInputStream = new DataInputStream(fileInputStream)) {
73 | actorBlock.loadParameters(manager, dataInputStream);
74 | ParameterList parameters = actorBlock.getParameters();
75 | List keys = parameters.keys();
76 | for (String key : keys) {
77 | var data = parameters.get(key).getArray();
78 | data.setRequiresGradient(true);
79 | }
80 | } catch (IOException e) {
81 | log.error("read file error");
82 | throw new RuntimeException(e);
83 | } catch (MalformedModelException e) {
84 | log.error("load model error");
85 | throw new RuntimeException(e);
86 | }
87 | }
88 |
89 | public static void loadStreamModel(InputStream inputStream, Block actorBlock, NDManager manager) {
90 | try (var dataInputStream = new DataInputStream(inputStream)) {
91 | log.info("load stream model.");
92 | actorBlock.loadParameters(manager, dataInputStream);
93 | ParameterList parameters = actorBlock.getParameters();
94 | List keys = parameters.keys();
95 | for (String key : keys) {
96 | var data = parameters.get(key).getArray();
97 | data.setRequiresGradient(true);
98 | }
99 | } catch (IOException e) {
100 | log.error("IOException");
101 | throw new RuntimeException(e);
102 | } catch (MalformedModelException e) {
103 | log.error("MalformedModelException");
104 | throw new RuntimeException(e);
105 | }
106 | }
107 |
108 | // edge computing, truncate episode, no done
109 | public static NDArray getReturn(NDArray rewards, NDArray nextStateValues, float gamma, int episodeLimit) {
110 | var res = rewards.zerosLike();
111 | for (long i = episodeLimit - 1; i >= 0; i--) {
112 | NDArray nextReturn;
113 | if (i == episodeLimit - 1) {
114 | nextReturn = nextStateValues.get(episodeLimit - 1);
115 | } else {
116 | nextReturn = res.get(i + 1);
117 | }
118 | var currentReturn = rewards.get(i).add(nextReturn.mul(gamma));
119 | res.set(new NDIndex(i), currentReturn);
120 | }
121 | return res;
122 | }
123 |
124 | public static NDArray getReturn(NDArray rewards, float gamma, int episodeLimit) {
125 | var res = rewards.zerosLike();
126 | for (long i = episodeLimit - 1; i >= 0; i--) {
127 | NDArray nextReturn;
128 | if (i == episodeLimit - 1) {
129 | nextReturn = rewards.get(i);
130 | } else {
131 | nextReturn = res.get(i + 1);
132 | }
133 | var currentReturn = rewards.get(i).add(nextReturn.mul(gamma));
134 | res.set(new NDIndex(i), currentReturn);
135 | }
136 | return res;
137 | }
138 |
139 |
140 | public static NDArray getGae(NDArray rewards, NDArray stateValues, NDArray nextStateValues, NDManager manager, float gamma, int episodeLimit, float gaeLambda) {
141 | var advantages = rewards.zerosLike();
142 | var deltas = rewards.add(nextStateValues.mul(gamma)).sub(stateValues);
143 | var nextAdvantage = manager.create(0.0f);
144 | for (long i = episodeLimit - 1; i >= 0; i--) {
145 | var advantage = deltas.get(i).add(nextAdvantage.mul(gamma).mul(gaeLambda));
146 | advantages.set(new NDIndex(i), advantage);
147 | nextAdvantage = advantage;
148 | }
149 | return advantages;
150 | }
151 | }
152 |
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/util/FileUtils.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.util;
2 |
3 | import lombok.extern.slf4j.Slf4j;
4 |
5 | import java.io.*;
6 | import java.nio.file.Files;
7 | import java.nio.file.Path;
8 |
9 | @Slf4j
10 | public class FileUtils {
11 | public static void recursiveDelete(File file) {
12 | if (!file.exists())
13 | return;
14 |
15 | if (file.isDirectory()) {
16 | for (File f : file.listFiles()) {
17 | // 调用递归
18 | recursiveDelete(f);
19 | }
20 | }
21 | file.delete();
22 | log.info("delete file/folder: {}", file.getAbsolutePath());
23 | }
24 |
25 | public static void writeObject(Object data, Path path) {
26 | try {
27 | Files.createDirectories(path.getParent());
28 | } catch (IOException e) {
29 | throw new RuntimeException(e);
30 | }
31 | try (FileOutputStream f = new FileOutputStream(path.toFile());
32 | ObjectOutput s = new ObjectOutputStream(f)) {
33 | s.writeObject(data);
34 | } catch (IOException e) {
35 | throw new RuntimeException(e);
36 | }
37 | }
38 |
39 | public static Object readObject(Path path){
40 | try (FileInputStream in = new FileInputStream(path.toFile());
41 | ObjectInputStream s = new ObjectInputStream(in)) {
42 | return s.readObject();
43 | } catch (IOException | ClassNotFoundException e) {
44 | throw new RuntimeException(e);
45 | }
46 | }
47 |
48 | }
49 |
--------------------------------------------------------------------------------
/edge-algorithm/src/main/java/cn/edu/scut/util/PlotUtils.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.util;
2 |
3 | import tech.tablesaw.plotly.components.Axis;
4 | import tech.tablesaw.plotly.components.Figure;
5 | import tech.tablesaw.plotly.components.Layout;
6 | import tech.tablesaw.plotly.traces.ScatterTrace;
7 |
8 | public class PlotUtils {
9 | public static Figure plot(double[][] x, double[][] y, String[] traceLabels, String xLabel, String yLabel) {
10 | ScatterTrace[] traces = new ScatterTrace[x.length];
11 | for (int i = 0; i < traces.length; i++) {
12 | traces[i] =
13 | ScatterTrace.builder(x[i], y[i])
14 | .mode(ScatterTrace.Mode.LINE)
15 | .name(traceLabels[i])
16 | .build();
17 | }
18 | Layout layout =
19 | Layout.builder()
20 | .showLegend(true)
21 | .xAxis(Axis.builder().title(xLabel).build())
22 | .yAxis(Axis.builder().title(yLabel).build())
23 | .build();
24 | return new Figure(layout, traces);
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/edge-api/README.md:
--------------------------------------------------------------------------------
1 | DAO
2 | utils
--------------------------------------------------------------------------------
/edge-api/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | edge-computing
7 | cn.edu.scut
8 | 1.0-SNAPSHOT
9 |
10 | 4.0.0
11 |
12 | edge-api
13 |
14 |
15 | 17
16 | 17
17 | UTF-8
18 |
19 |
20 |
21 |
22 | com.alibaba.cloud
23 | spring-cloud-starter-alibaba-nacos-discovery
24 |
25 |
26 |
27 | mysql
28 | mysql-connector-java
29 |
30 |
31 | com.alibaba
32 | druid-spring-boot-starter
33 |
34 |
35 | com.baomidou
36 | mybatis-plus-boot-starter
37 |
38 |
39 |
40 | org.springframework.boot
41 | spring-boot-starter-web
42 |
43 |
44 | ai.djl
45 | api
46 |
47 |
48 |
49 | org.projectlombok
50 | lombok
51 | true
52 |
53 |
54 |
55 | org.apache.hadoop
56 | hadoop-common
57 |
58 |
59 |
60 | org.apache.hadoop
61 | hadoop-hdfs
62 |
63 |
64 |
65 | org.apache.hadoop
66 | hadoop-hdfs-client
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/Constants.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.bean;
2 |
3 |
/**
 * Decimal (SI, 1000-based) multipliers, plus {@code Byte} = 8 bits per byte.
 * The binary (1024-based) counterparts are in StoreConstants.
 */
public enum Constants {

    Byte(8), Kilo(1000), Mega(1000 * 1000), Giga(1000 * 1000 * 1000);

    /** The multiplier. Final so these shared constants cannot be changed at runtime. */
    public final Integer value;

    Constants(Integer value) {
        this.value = value;
    }
}
14 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/EdgeNode.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.bean;
2 |
3 | import com.baomidou.mybatisplus.annotation.TableField;
4 | import lombok.Data;
5 |
6 | @Data
7 | public class EdgeNode {
8 | private Integer id;
9 | private String name;
10 | private Long cpuNum;
11 | @TableField(exist = false)
12 | private Long capacity; // capacity = cpuNum * cpuCapacity
13 | private Double taskRate;
14 | private Double executionFailureRate;
15 | }
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/Link.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.bean;
2 |
3 | import lombok.Data;
4 |
5 | @Data
6 | public class Link {
7 | private Integer id;
8 | private String source;
9 | private String destination;
10 | private Double transmissionRate;
11 | private Double transmissionFailureRate;
12 | }
13 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/RatcVo.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.bean;
2 |
3 | import lombok.Data;
4 |
5 | @Data
6 | public class RatcVo {
7 | private String edgeId;
8 | private Double executionFailureRate;
9 | private Long capacity;
10 | private Long waitingTime;
11 | private Long totalTime;
12 | }
13 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/StoreConstants.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.bean;
2 |
/**
 * Binary (1024-based) storage multipliers, plus {@code Byte} = 8 bits per byte.
 * The decimal (1000-based) counterparts are in Constants.
 */
public enum StoreConstants {

    Byte(8), Kilo(1024), Mega(1024 * 1024), Giga(1024 * 1024 * 1024);

    /** The multiplier. Final so these shared constants cannot be changed at runtime. */
    public final Integer value;

    StoreConstants(Integer value) {
        this.value = value;
    }
}
13 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/Task.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.bean;
2 |
3 | import com.baomidou.mybatisplus.annotation.TableField;
4 | import lombok.Data;
5 |
6 | import java.time.LocalDateTime;
7 |
8 | @Data
9 | public class Task {
10 | private Long id;
11 | private Integer timeSlot;
12 | private String source;
13 | private String destination;
14 | private TaskStatus status;
15 | // KB
16 | private Long taskSize;
17 | private Long taskComplexity;
18 | // cycle
19 | private Long cpuCycle; // cpuCycle = taskSize * taskComplexity
20 | // s
21 | private Long deadline;
22 | private Long transmissionTime;
23 | private Long executionTime;
24 | private Long transmissionWaitingTime;
25 | private Long executionWaitingTime;
26 |
27 | @TableField(exist = false)
28 | private LocalDateTime arrivalTime;
29 | @TableField(exist = false)
30 | private LocalDateTime beginTransmissionTime;
31 | @TableField(exist = false)
32 | private LocalDateTime endTransmissionTime;
33 | @TableField(exist = false)
34 | private LocalDateTime beginExecutionTime;
35 | @TableField(exist = false)
36 | private LocalDateTime endExecutionTime;
37 | @TableField(exist = false)
38 | private Double transmissionFailureRate;
39 | @TableField(exist = false)
40 | private Double executionFailureRate;
41 |
42 | private String availAction;
43 | }
44 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/bean/TaskStatus.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.bean;
2 |
3 | public enum TaskStatus {
4 | EMPTY, NEW, SUCCESS, DROP, TRANSMISSION_FAILURE, EXECUTION_FAILURE
5 | }
6 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/mapper/EdgeNodeMapper.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.mapper;
2 |
3 | import cn.edu.scut.bean.EdgeNode;
4 | import com.baomidou.mybatisplus.core.mapper.BaseMapper;
5 | import org.apache.ibatis.annotations.Mapper;
6 |
7 | @Mapper
8 | public interface EdgeNodeMapper extends BaseMapper {
9 | }
10 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/mapper/LinkMapper.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.mapper;
2 |
3 | import cn.edu.scut.bean.Link;
4 | import com.baomidou.mybatisplus.core.mapper.BaseMapper;
5 | import org.apache.ibatis.annotations.Mapper;
6 |
7 | @Mapper
8 | public interface LinkMapper extends BaseMapper {
9 | }
10 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/mapper/TaskMapper.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.mapper;
2 |
3 | import cn.edu.scut.bean.Task;
4 | import com.baomidou.mybatisplus.core.mapper.BaseMapper;
5 | import org.apache.ibatis.annotations.Mapper;
6 |
7 | @Mapper
8 | public interface TaskMapper extends BaseMapper {
9 | }
10 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/service/EdgeNodeService.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.service;
2 |
3 | import cn.edu.scut.bean.EdgeNode;
4 | import cn.edu.scut.mapper.EdgeNodeMapper;
5 | import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
6 | import org.springframework.stereotype.Service;
7 |
8 | @Service
9 | public class EdgeNodeService extends ServiceImpl {
10 | }
11 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/service/LinkService.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.service;
2 |
3 | import cn.edu.scut.bean.Link;
4 | import cn.edu.scut.mapper.LinkMapper;
5 | import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
6 | import org.springframework.stereotype.Service;
7 |
8 | @Service
9 | public class LinkService extends ServiceImpl {
10 | }
11 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/service/TaskService.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.service;
2 |
3 | import cn.edu.scut.bean.Task;
4 | import cn.edu.scut.mapper.TaskMapper;
5 | import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
6 | import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
7 | import org.springframework.beans.factory.annotation.Autowired;
8 | import org.springframework.stereotype.Service;
9 |
10 | import java.util.Map;
11 |
12 | @Service
13 | public class TaskService extends ServiceImpl {
14 |
15 | @Autowired
16 | private TaskMapper taskMapper;
17 |
18 | public float getSuccessRate(Map map) {
19 | Long successTasks = taskMapper.selectCount(new QueryWrapper().eq("status", "SUCCESS").gt("id", map.get("startIndex")));
20 | Long totalTasks = taskMapper.selectCount(new QueryWrapper().gt("id", map.get("startIndex")));
21 | return Float.valueOf(successTasks) / Float.valueOf(totalTasks);
22 | }
23 |
24 | public double getSuccessRate() {
25 | Long successTasks = taskMapper.selectCount(new QueryWrapper().eq("status", "SUCCESS"));
26 | Long totalTasks = taskMapper.selectCount(new QueryWrapper()) - taskMapper.selectCount(new QueryWrapper().eq("status", "EMPTY"));
27 | return Double.valueOf(successTasks) / Double.valueOf(totalTasks);
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/edge-api/src/main/java/cn/edu/scut/util/ArrayUtils.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.util;
2 |
3 | import java.util.Arrays;
4 | import java.util.List;
5 | import java.util.stream.Collectors;
6 |
7 | public class ArrayUtils {
8 | public static float[] toFloatArray(List list) {
9 | var res = new float[list.size()];
10 | for (int i = 0; i < list.size(); i++) {
11 | res[i] = list.get(i);
12 | }
13 | return res;
14 | }
15 |
16 | public static double[] toDoubleArray(List list) {
17 | var res = new double[list.size()];
18 | for (int i = 0; i < list.size(); i++) {
19 | res[i] = list.get(i);
20 | }
21 | return res;
22 | }
23 |
24 | public static String arrayToString(int[] array) {
25 | return Arrays.stream(array).mapToObj(Integer::toString).collect(Collectors.joining(","));
26 | }
27 |
28 | public static double[] floatToDoubleArray(float[] x) {
29 | double[] ret = new double[x.length];
30 | for (int i = 0; i < x.length; i++) {
31 | ret[i] = x[i];
32 | }
33 | return ret;
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/edge-controller/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM lhc-ubuntu-20:v1.1
2 | LABEL maintainer=hongcai
3 | ENV PATH /root/lhc_dev/jdk-17.0.5/bin:${PATH}
4 | COPY target/edge-controller-1.0-SNAPSHOT.jar /root/app.jar
5 | WORKDIR /root
6 | ENTRYPOINT ["sh","-c","java -jar app.jar"]
--------------------------------------------------------------------------------
/edge-controller/README.md:
--------------------------------------------------------------------------------
1 | mock the users to seed task to edge nodes.
2 | generate the configuration of task, edge node and link.
3 | start, stop the experiment.
4 |
--------------------------------------------------------------------------------
/edge-controller/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | edge-computing
7 | cn.edu.scut
8 | 1.0-SNAPSHOT
9 |
10 | 4.0.0
11 |
12 | edge-controller
13 |
14 |
15 | 17
16 | 17
17 |
18 |
19 |
20 |
21 |
22 | mysql
23 | mysql-connector-java
24 |
25 |
26 | com.alibaba
27 | druid-spring-boot-starter
28 |
29 |
30 | com.baomidou
31 | mybatis-plus-boot-starter
32 |
33 |
34 |
35 |
36 |
37 | com.alibaba.cloud
38 | spring-cloud-starter-alibaba-nacos-config
39 |
40 |
41 | org.springframework.cloud
42 | spring-cloud-starter-bootstrap
43 |
44 |
45 |
46 | com.alibaba.cloud
47 | spring-cloud-starter-alibaba-nacos-discovery
48 |
49 |
50 |
51 |
52 | org.springframework.boot
53 | spring-boot-starter-web
54 |
55 |
56 |
57 | org.springframework.boot
58 | spring-boot-starter-logging
59 |
60 |
61 |
62 | org.projectlombok
63 | lombok
64 | true
65 |
66 |
67 |
68 |
69 | org.springframework.boot
70 | spring-boot-starter-test
71 | test
72 |
73 |
74 |
75 |
76 | org.springframework.cloud
77 | spring-cloud-loadbalancer
78 |
79 |
80 |
81 | cn.edu.scut
82 | edge-api
83 | 1.0-SNAPSHOT
84 |
85 |
86 |
87 | org.apache.commons
88 | commons-math3
89 |
90 |
91 | cn.edu.scut
92 | edge-algorithm
93 | 1.0-SNAPSHOT
94 | test
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 | org.springframework.boot
103 | spring-boot-maven-plugin
104 |
105 | true
106 | true
107 |
108 |
109 |
110 |
111 | repackage
112 |
113 |
114 |
115 |
116 |
117 |
118 |
--------------------------------------------------------------------------------
/edge-controller/src/main/java/cn/edu/scut/ControllerApp.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut;
2 |
3 | import cn.edu.scut.utils.SpringBeanUtils;
4 | import org.springframework.boot.SpringApplication;
5 | import org.springframework.boot.autoconfigure.SpringBootApplication;
6 | import org.springframework.cloud.client.discovery.EnableDiscoveryClient;
7 | import org.springframework.context.ConfigurableApplicationContext;
8 | import org.springframework.scheduling.annotation.EnableAsync;
9 |
10 | @EnableDiscoveryClient
11 | @SpringBootApplication
12 | @EnableAsync
13 | public class ControllerApp {
14 | public static void main(String[] args) {
15 | ConfigurableApplicationContext context = SpringApplication.run(ControllerApp.class, args);
16 | SpringBeanUtils.setApplicationContext(context);
17 | }
18 | }
--------------------------------------------------------------------------------
/edge-controller/src/main/java/cn/edu/scut/config/DataGenerationConfig.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.config;
2 |
3 | import org.apache.commons.math3.distribution.UniformIntegerDistribution;
4 | import org.apache.commons.math3.distribution.UniformRealDistribution;
5 | import org.apache.commons.math3.random.JDKRandomGenerator;
6 | import org.apache.commons.math3.random.RandomGenerator;
7 | import org.springframework.beans.factory.annotation.Value;
8 | import org.springframework.context.annotation.Bean;
9 | import org.springframework.context.annotation.Configuration;
10 |
11 | @Configuration
12 | public class DataGenerationConfig {
13 |
14 | @Value("${edgeComputing.minExecutionFailureRate}")
15 | private Double minExecutionFailureRate;
16 |
17 | @Value("${edgeComputing.maxExecutionFailureRate}")
18 | private Double maxExecutionFailureRate;
19 |
20 | @Value("${edgeComputing.maxCpuCore}")
21 | private Integer maxCpuCore;
22 |
23 | @Value("${edgeComputing.minCpuCore}")
24 | private Integer minCpuCore;
25 |
26 | @Value("${edgeComputing.minTaskRate}")
27 | private Double minTaskRate;
28 |
29 | @Value("${edgeComputing.maxTaskRate}")
30 | private Double maxTaskRate;
31 |
32 | @Value("${edgeComputing.minTransmissionRate}")
33 | private Double minTransmissionRate;
34 |
35 | @Value("${edgeComputing.maxTransmissionRate}")
36 | private Double maxTransmissionRate;
37 |
38 | @Value("${edgeComputing.minTransmissionFailureRate}")
39 | private Double minTransmissionFailureRate;
40 |
41 | @Value("${edgeComputing.maxTransmissionFailureRate}")
42 | private Double maxTransmissionFailureRate;
43 |
44 | @Value("${edgeComputing.edgeNodeSeed}")
45 | private int edgeNodeSeed;
46 |
47 | @Value("${edgeComputing.taskSeed}")
48 | private int taskSeed;
49 |
50 | @Value("${edgeComputing.minTaskSize}")
51 | private int minTaskSize;
52 |
53 | @Value("${edgeComputing.maxTaskSize}")
54 | private int maxTaskSize;
55 |
56 | @Value("${edgeComputing.minTaskComplexity}")
57 | private int minTaskComplexity;
58 |
59 | @Value("${edgeComputing.maxTaskComplexity}")
60 | private int maxTaskComplexity;
61 |
62 | @Bean
63 | public RandomGenerator randomGenerator() {
64 | var random = new JDKRandomGenerator();
65 | random.setSeed(edgeNodeSeed);
66 | return random;
67 | }
68 |
69 | @Bean
70 | public UniformRealDistribution executionFailureRandom(RandomGenerator randomGenerator) {
71 | return new UniformRealDistribution(randomGenerator, minExecutionFailureRate, maxExecutionFailureRate);
72 | }
73 |
74 | @Bean
75 | public UniformIntegerDistribution cpuCoreRandom(RandomGenerator randomGenerator) {
76 | return new UniformIntegerDistribution(randomGenerator, minCpuCore / 4, maxCpuCore / 4);
77 | }
78 |
79 | @Bean
80 | public UniformRealDistribution taskRateRandom(RandomGenerator randomGenerator) {
81 | return new UniformRealDistribution(randomGenerator, minTaskRate, maxTaskRate);
82 | }
83 |
84 | @Bean
85 | public UniformRealDistribution transmissionRateRandom(RandomGenerator randomGenerator) {
86 | return new UniformRealDistribution(randomGenerator, minTransmissionRate, maxTransmissionRate);
87 | }
88 |
89 | @Bean
90 | public UniformRealDistribution transmissionFailureRateRandom(RandomGenerator randomGenerator) {
91 | return new UniformRealDistribution(randomGenerator, minTransmissionFailureRate, maxTransmissionFailureRate);
92 | }
93 |
94 | @Bean
95 | public RandomGenerator taskRandomGenerator() {
96 | var random = new JDKRandomGenerator();
97 | random.setSeed(taskSeed);
98 | return random;
99 | }
100 |
101 | @Bean
102 | public UniformIntegerDistribution taskSizeRandom(RandomGenerator taskRandomGenerator) {
103 | return new UniformIntegerDistribution(taskRandomGenerator, minTaskSize, maxTaskSize);
104 | }
105 |
106 | @Bean
107 | public UniformIntegerDistribution taskComplexityRandom(RandomGenerator taskRandomGenerator) {
108 | return new UniformIntegerDistribution(taskRandomGenerator, minTaskComplexity, maxTaskComplexity);
109 | }
110 | }
111 |
--------------------------------------------------------------------------------
/edge-controller/src/main/java/cn/edu/scut/config/SpringConfig.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.config;
2 |
3 | import org.springframework.cloud.client.loadbalancer.LoadBalanced;
4 | import org.springframework.context.annotation.Bean;
5 | import org.springframework.context.annotation.Configuration;
6 | import org.springframework.web.client.RestTemplate;
7 |
8 | @Configuration
9 | public class SpringConfig {
10 |
11 | @Bean
12 | @LoadBalanced
13 | RestTemplate restTemplate() {
14 | return new RestTemplate();
15 | }
16 |
17 | }
18 |
--------------------------------------------------------------------------------
/edge-controller/src/main/java/cn/edu/scut/controller/EdgeNodeController.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.controller;
2 |
3 | import cn.edu.scut.service.EdgeConfigService;
4 | import cn.edu.scut.service.UserService;
5 | import lombok.extern.apachecommons.CommonsLog;
6 | import org.springframework.beans.factory.annotation.Autowired;
7 | import org.springframework.beans.factory.annotation.Value;
8 | import org.springframework.web.bind.annotation.GetMapping;
9 | import org.springframework.web.bind.annotation.RestController;
10 | import org.springframework.web.client.RestTemplate;
11 |
12 | @RestController
13 | @CommonsLog
14 | public class EdgeNodeController {
15 | @Autowired
16 | private RestTemplate restTemplate;
17 |
18 | @Autowired
19 | private EdgeConfigService edgeConfigService;
20 |
21 | @Autowired
22 | private UserService userService;
23 |
24 | @Value("${edgeComputing.edgeNodeNumber}")
25 | private int edgeNodeNumber;
26 |
27 | @GetMapping("/generate")
28 | public String generate() {
29 | log.info("generate all edge node configuration.");
30 | for (int i = 1; i <= edgeNodeNumber; i++) {
31 | String name = String.format("edge-node-%d", i);
32 | edgeConfigService.generateEdgeLink(name);
33 | }
34 | return "success\n";
35 | }
36 |
37 | @GetMapping("/init")
38 | public String init() {
39 | log.info("init all edge node configuration.");
40 | for (int i = 1; i <= edgeNodeNumber; i++) {
41 | String name = String.format("edge-node-%d", i);
42 | log.info("init " + name);
43 | String url = String.format("http://%s/cmd/init", name);
44 | restTemplate.getForObject(url, String.class);
45 | }
46 | return "success\n";
47 | }
48 |
49 | @GetMapping("/start")
50 | public String start() {
51 | log.info("start experiment.");
52 | userService.start();
53 | return "success\n";
54 | }
55 |
56 | @GetMapping("/stop")
57 | public String stop() {
58 | log.info("stop experiment.");
59 | userService.stop();
60 | return "success\n";
61 | }
62 |
63 | @GetMapping("/restart")
64 | public String restart() {
65 | log.info("restart experiment.");
66 | userService.restart();
67 | return "success\n";
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/edge-controller/src/main/java/cn/edu/scut/service/EdgeConfigService.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.service;
2 |
3 | import cn.edu.scut.bean.*;
4 | import cn.edu.scut.util.ArrayUtils;
5 | import lombok.extern.slf4j.Slf4j;
6 | import org.apache.commons.math3.distribution.UniformIntegerDistribution;
7 | import org.apache.commons.math3.distribution.UniformRealDistribution;
8 | import org.springframework.beans.factory.annotation.Autowired;
9 | import org.springframework.beans.factory.annotation.Value;
10 | import org.springframework.cloud.context.config.annotation.RefreshScope;
11 | import org.springframework.stereotype.Service;
12 |
13 | @Service
14 | @Slf4j
15 | @RefreshScope
16 | public class EdgeConfigService {
17 |
18 | @Autowired
19 | private UniformRealDistribution executionFailureRandom;
20 |
21 | @Autowired
22 | private UniformIntegerDistribution cpuCoreRandom;
23 |
24 | @Autowired
25 | private UniformRealDistribution taskRateRandom;
26 |
27 | @Autowired
28 | private UniformRealDistribution transmissionRateRandom;
29 |
30 | @Autowired
31 | private UniformRealDistribution transmissionFailureRateRandom;
32 |
33 | @Autowired
34 | private UniformIntegerDistribution taskSizeRandom;
35 |
36 | @Autowired
37 | private UniformIntegerDistribution taskComplexityRandom;
38 |
39 | @Autowired
40 | private EdgeNodeService edgeNodeService;
41 |
42 | @Autowired
43 | private LinkService linkService;
44 |
45 | @Value("${edgeComputing.minTransmissionFailureRate}")
46 | Double minTransmissionFailureRate;
47 |
48 | @Value("${edgeComputing.maxTransmissionRate}")
49 | Double maxTransmissionRate;
50 |
51 | @Value("${edgeComputing.minTaskSize}")
52 | Long minTaskSize;
53 |
54 | @Value("${edgeComputing.maxTaskSize}")
55 | Long maxTaskSize;
56 |
57 | @Value("${edgeComputing.minTaskComplexity}")
58 | Long minTaskComplexity;
59 |
60 | @Value("${edgeComputing.maxTaskComplexity}")
61 | Long maxTaskComplexity;
62 |
63 | @Value("${edgeComputing.deadline}")
64 | Long deadline;
65 |
66 | @Value("${edgeComputing.edgeNodeNumber}")
67 | private int edgeNodeNumber;
68 |
69 | public void generateEdgeLink(String name) {
70 | EdgeNode edgeNode = new EdgeNode();
71 | edgeNode.setName(name);
72 | edgeNode.setExecutionFailureRate(executionFailureRandom.sample());
73 | // ! int -> long
74 | Integer core = cpuCoreRandom.sample() * 4;
75 | edgeNode.setCpuNum(core.longValue());
76 | edgeNode.setTaskRate(taskRateRandom.sample());
77 | edgeNodeService.save(edgeNode);
78 |
79 | for (int i = 1; i <= edgeNodeNumber; i++) {
80 | String service = String.format("edge-node-%d", i);
81 | Link link = new Link();
82 | if (service.equals(name)) {
83 | link.setSource(name);
84 | link.setDestination(name);
85 | link.setTransmissionRate(maxTransmissionRate * Constants.Mega.value * Constants.Byte.value);
86 | link.setTransmissionFailureRate(minTransmissionFailureRate);
87 | } else {
88 | link.setSource(name);
89 | link.setDestination(service);
90 | link.setTransmissionRate(transmissionRateRandom.sample() * Constants.Mega.value * Constants.Byte.value);
91 | link.setTransmissionFailureRate(transmissionFailureRateRandom.sample());
92 | }
93 | linkService.save(link);
94 | }
95 | }
96 |
97 | public Task generateTask(String name) {
98 | Task task = new Task();
99 | task.setTaskSize(((long) taskSizeRandom.sample()) * StoreConstants.Kilo.value * StoreConstants.Byte.value);
100 | task.setTaskComplexity((long) taskComplexityRandom.sample());
101 | task.setCpuCycle(task.getTaskComplexity() * task.getTaskSize());
102 | task.setDeadline(deadline);
103 | task.setSource(name);
104 | task.setStatus(TaskStatus.NEW);
105 | var availAction = new int[edgeNodeNumber + 1];
106 | for (int i = 0; i < edgeNodeNumber; i++) {
107 | availAction[i] = 1;
108 | }
109 | task.setAvailAction(ArrayUtils.arrayToString(availAction));
110 | return task;
111 | }
112 |
113 | public Task generateEmptyTask(String name) {
114 | Task task = new Task();
115 | task.setTaskSize(0L);
116 | task.setTaskComplexity(0L);
117 | task.setCpuCycle(0L);
118 | task.setDeadline(0L);
119 | task.setSource(name);
120 | task.setDestination("null");
121 | task.setStatus(TaskStatus.EMPTY);
122 | var availAction = new int[edgeNodeNumber + 1];
123 | availAction[edgeNodeNumber] = 1;
124 | task.setAvailAction(ArrayUtils.arrayToString(availAction));
125 | task.setExecutionTime(0L);
126 | task.setExecutionWaitingTime(0L);
127 | task.setTransmissionTime(0L);
128 | task.setTransmissionWaitingTime(0L);
129 | return task;
130 | }
131 | }
--------------------------------------------------------------------------------
/edge-controller/src/main/java/cn/edu/scut/service/UserService.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.service;
2 |
3 | import cn.edu.scut.thread.UserRunnable;
4 | import cn.edu.scut.utils.SpringBeanUtils;
5 | import org.springframework.beans.factory.annotation.Value;
6 | import org.springframework.stereotype.Service;
7 |
8 | import java.util.concurrent.ScheduledThreadPoolExecutor;
9 | import java.util.concurrent.TimeUnit;
10 |
11 | @Service
12 | public class UserService {
13 |
14 | @Value("${edgeComputing.timeSlot}")
15 | private int timeSlot;
16 |
17 | private ScheduledThreadPoolExecutor threadPoolExecutor = new ScheduledThreadPoolExecutor(1);
18 |
19 | public void start() {
20 | UserRunnable userRunnable = SpringBeanUtils.applicationContext.getBean(UserRunnable.class);
21 | threadPoolExecutor.scheduleAtFixedRate(userRunnable, 0, timeSlot, TimeUnit.MILLISECONDS);
22 | }
23 |
24 | public void stop() {
25 | threadPoolExecutor.shutdown();
26 | }
27 |
28 | public void restart() {
29 | threadPoolExecutor = new ScheduledThreadPoolExecutor(1);
30 | UserRunnable userRunnable = SpringBeanUtils.applicationContext.getBean(UserRunnable.class);
31 | userRunnable.setTimeSlot(0);
32 | threadPoolExecutor.scheduleAtFixedRate(userRunnable, 0, timeSlot, TimeUnit.MILLISECONDS);
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/edge-controller/src/main/java/cn/edu/scut/thread/UserRunnable.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.thread;
2 |
3 | import cn.edu.scut.bean.EdgeNode;
4 | import cn.edu.scut.service.EdgeConfigService;
5 | import cn.edu.scut.service.EdgeNodeService;
6 | import cn.edu.scut.service.TaskService;
7 | import lombok.Setter;
8 | import lombok.extern.slf4j.Slf4j;
9 | import org.apache.commons.math3.random.RandomGenerator;
10 | import org.springframework.beans.factory.annotation.Autowired;
11 | import org.springframework.beans.factory.annotation.Value;
12 | import org.springframework.context.annotation.Scope;
13 | import org.springframework.stereotype.Component;
14 | import org.springframework.web.client.RestTemplate;
15 |
16 | import java.util.List;
17 |
18 | @Setter
19 | @Slf4j
20 | @Scope("prototype")
21 | @Component
22 | public class UserRunnable implements Runnable {
23 |
24 | @Autowired
25 | private RestTemplate restTemplate;
26 |
27 | @Autowired
28 | private RandomGenerator taskRandomGenerator;
29 |
30 | @Autowired
31 | private EdgeConfigService edgeConfigService;
32 |
33 | @Autowired
34 | private EdgeNodeService edgeNodeService;
35 |
36 | @Autowired
37 | private TaskService taskService;
38 |
39 | private int timeSlot = 0;
40 |
41 | @Value("${edgeComputing.episodeLimit}")
42 | private int episodeLimit;
43 |
44 | @Override
45 | public void run() {
46 | timeSlot += 1;
47 | if (timeSlot > episodeLimit + 1) {
48 | return;
49 | }
50 | try {
51 | List edgeNodes = edgeNodeService.list();
52 | for (EdgeNode edgeNode : edgeNodes) {
53 | boolean flag = taskRandomGenerator.nextDouble() < edgeNode.getTaskRate();
54 | if (flag) {
55 | var task = edgeConfigService.generateTask(edgeNode.getName());
56 | task.setTimeSlot(timeSlot);
57 | taskService.save(task);
58 | String url = String.format("http://%s/user/task", edgeNode.getName());
59 | restTemplate.postForObject(url, task, String.class);
60 | } else {
61 | var task = edgeConfigService.generateEmptyTask(edgeNode.getName());
62 | task.setTimeSlot(timeSlot);
63 | taskService.save(task);
64 | }
65 | }
66 | } catch (Exception e) {
67 | e.printStackTrace();
68 | }
69 | }
70 | }
--------------------------------------------------------------------------------
/edge-controller/src/main/java/cn/edu/scut/utils/SpringBeanUtils.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.utils;
2 |
3 | import org.springframework.context.ApplicationContext;
4 |
5 | public class SpringBeanUtils {
6 | public static ApplicationContext applicationContext;
7 |
8 | public static void setApplicationContext(ApplicationContext applicationContext) {
9 | SpringBeanUtils.applicationContext = applicationContext;
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/edge-controller/src/main/resources/META-INF/MANIFEST.MF:
--------------------------------------------------------------------------------
1 | Manifest-Version: 1.0
2 | Main-Class: cn.edu.scut.ControllerApp
3 |
4 |
--------------------------------------------------------------------------------
/edge-controller/src/main/resources/bootstrap.yaml:
--------------------------------------------------------------------------------
1 | spring:
2 | application:
3 | name: edge-controller
4 | profiles:
5 | # active: edge-computing,mappo
6 | # active: edge-computing,masac
7 | # active: edge-computing,reliability-two-choice
8 | # active: edge-computing,reactive
9 | active: edge-computing,random
10 | # active: edge-computing,local
11 | cloud:
12 | nacos:
13 | discovery:
14 | server-addr: 222.201.187.50:30848
15 | watch-delay: 3000
16 | config:
17 | server-addr: 222.201.187.50:30848
18 | prefix: application
19 | file-extension: yaml
--------------------------------------------------------------------------------
/edge-experiment/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM lhc-ubuntu-20:v1.1
2 | LABEL maintainer=hongcai
3 | ENV PATH /root/lhc_dev/jdk-17.0.5/bin:${PATH}
4 | COPY target/edge-experiment-1.0-SNAPSHOT.jar /root/app.jar
5 | WORKDIR /root
6 | ENTRYPOINT ["sh","-c","java -jar app.jar"]
--------------------------------------------------------------------------------
/edge-experiment/README.md:
--------------------------------------------------------------------------------
1 | run the experiment.
2 |
3 | - online heuristic.
4 | - online DRL.
5 | - offline DRL.
6 | - offline to online DRL.
--------------------------------------------------------------------------------
/edge-experiment/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | 4.0.0
6 |
7 | cn.edu.scut
8 | edge-computing
9 | 1.0-SNAPSHOT
10 |
11 |
12 | edge-experiment
13 |
14 |
15 | 17
16 | 17
17 | UTF-8
18 |
19 |
20 |
21 |
22 |
23 | mysql
24 | mysql-connector-java
25 |
26 |
27 | com.alibaba
28 | druid-spring-boot-starter
29 |
30 |
31 | com.baomidou
32 | mybatis-plus-boot-starter
33 |
34 |
35 |
36 | cn.edu.scut
37 | edge-algorithm
38 | 1.0-SNAPSHOT
39 |
40 |
41 |
42 | com.alibaba.cloud
43 | spring-cloud-starter-alibaba-nacos-discovery
44 |
45 |
46 | com.alibaba.cloud
47 | spring-cloud-starter-alibaba-nacos-config
48 |
49 |
50 | org.springframework.cloud
51 | spring-cloud-starter-bootstrap
52 |
53 |
54 |
55 |
56 | org.springframework.boot
57 | spring-boot-starter-web
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 | org.springframework.boot
75 | spring-boot-starter-logging
76 |
77 |
78 |
79 |
80 | org.projectlombok
81 | lombok
82 | true
83 |
84 |
85 |
86 |
87 | org.springframework.boot
88 | spring-boot-starter-test
89 | test
90 |
91 |
92 |
93 |
94 | org.springframework.cloud
95 | spring-cloud-loadbalancer
96 |
97 |
98 |
99 |
100 | cn.edu.scut
101 | edge-api
102 | 1.0-SNAPSHOT
103 |
104 |
105 |
106 |
107 | org.apache.commons
108 | commons-math3
109 |
110 |
111 |
112 |
113 | org.apache.hadoop
114 | hadoop-common
115 |
116 |
117 | org.slf4j
118 | slf4j-api
119 |
120 |
121 | org.slf4j
122 | slf4j-reload4j
123 |
124 |
125 |
126 |
127 | org.apache.hadoop
128 | hadoop-hdfs
129 |
130 |
131 | org.apache.hadoop
132 | hadoop-hdfs-client
133 |
134 |
135 |
136 |
137 | tech.tablesaw
138 | tablesaw-core
139 | LATEST
140 |
141 |
142 | tech.tablesaw
143 | tablesaw-jsplot
144 | 0.43.1
145 |
146 |
147 |
148 |
149 |
150 |
151 | org.springframework.boot
152 | spring-boot-maven-plugin
153 |
154 | true
155 | true
156 |
157 |
158 |
159 |
160 | repackage
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/ExperimentApp.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut;
2 |
3 | import cn.edu.scut.runner.*;
4 | import cn.edu.scut.util.SpringBeanUtils;
5 | import org.springframework.boot.SpringApplication;
6 | import org.springframework.boot.autoconfigure.SpringBootApplication;
7 |
8 | @SpringBootApplication
9 | public class ExperimentApp {
10 | public static void main(String[] args) {
11 | var context = SpringApplication.run(ExperimentApp.class, args);
12 | SpringBeanUtils.setApplicationContext(context);
13 | var env = context.getEnvironment();
14 | var runnerType = env.getProperty("edgeComputing.runner", String.class);
15 | assert runnerType != null;
16 | var runner = switch (runnerType) {
17 | case "rl-online" -> context.getBean(OnlineMARLTrainingRunner.class);
18 | case "rl-offline" -> context.getBean(OfflineMARLTrainingRunner.class);
19 | case "rl-test" -> context.getBean(OnlineMARLTestingRunner.class);
20 | case "heuristic" -> context.getBean(OnlineHeuristicTestRunner.class);
21 | case "heuristic-data" -> context.getBean(OnlineHeuristicDataRunner.class);
22 | default -> throw new RuntimeException("error in runner type.");
23 | };
24 | // waiting for edge-nodes and edge-controller start.
25 | try {
26 | Thread.sleep(5000);
27 | } catch (InterruptedException e) {
28 | throw new RuntimeException(e);
29 | }
30 | runner.run();
31 | }
32 | }
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/config/HadoopConfig.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.config;
2 |
3 | import org.apache.hadoop.fs.FileSystem;
4 | import org.springframework.beans.factory.annotation.Value;
5 | import org.springframework.context.annotation.Bean;
6 | import org.springframework.context.annotation.Configuration;
7 |
8 | import java.io.IOException;
9 | import java.net.URI;
10 | import java.net.URISyntaxException;
11 |
12 | @Configuration
13 | public class HadoopConfig {
14 |
15 | @Value("${hadoop.hdfs.url}")
16 | private String hdfsUrl;
17 |
18 | @Bean
19 | public FileSystem fileSystem() {
20 | URI uri;
21 | try {
22 | uri = new URI(hdfsUrl);
23 | } catch (URISyntaxException e) {
24 | throw new RuntimeException(e);
25 | }
26 | var configuration = new org.apache.hadoop.conf.Configuration();
27 | String user = "hongcai";
28 | FileSystem fileSystem;
29 | try {
30 | fileSystem = FileSystem.get(uri, configuration, user);
31 | } catch (IOException | InterruptedException e) {
32 | throw new RuntimeException(e);
33 | }
34 | return fileSystem;
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/config/SpringConfig.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.config;
2 |
3 | import org.springframework.beans.factory.annotation.Value;
4 | import org.springframework.cloud.client.loadbalancer.LoadBalanced;
5 | import org.springframework.context.annotation.Bean;
6 | import org.springframework.context.annotation.Configuration;
7 | import org.springframework.web.client.RestTemplate;
8 |
9 | import java.util.Random;
10 |
11 | @Configuration
12 | public class SpringConfig {
13 | @Value("${edgeComputing.seed}")
14 | Integer seed;
15 |
16 | @Bean
17 | @LoadBalanced
18 | RestTemplate restTemplate() {
19 | return new RestTemplate();
20 | }
21 |
22 | @Bean
23 | public Random schedulerRandom() {
24 | return new Random(seed);
25 | }
26 | }
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/controller/ModelController.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.controller;
2 |
3 | import cn.edu.scut.agent.IMAAgent;
4 | import cn.edu.scut.service.TransitionService;
5 | import lombok.extern.slf4j.Slf4j;
6 | import org.springframework.beans.factory.annotation.Autowired;
7 | import org.springframework.web.bind.annotation.GetMapping;
8 | import org.springframework.web.bind.annotation.PathVariable;
9 | import org.springframework.web.bind.annotation.RestController;
10 |
11 | @RestController
12 | @Slf4j
13 | public class ModelController {
14 |
15 |
16 | @Autowired(required = false)
17 | private IMAAgent agent;
18 |
19 | @Autowired
20 | private TransitionService transitionService;
21 |
22 | @GetMapping("/action/{taskId}")
23 | public Integer action(@PathVariable("taskId") Long taskId) {
24 | var availAction = transitionService.getAvailAction(taskId);
25 | var state = transitionService.getState(taskId);
26 | int i = agent.selectAction(state, availAction, false) + 1;
27 | return i;
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/IRunner.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.runner;
2 |
/**
 * One experiment workflow. {@code ExperimentApp} picks an implementation based on the
 * {@code edgeComputing.runner} property and invokes {@link #run()} once after startup.
 */
public interface IRunner {

    /** Executes the whole experiment workflow. */
    void run();
}
6 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/OfflineMARLTrainingRunner.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.runner;
2 |
3 | import cn.edu.scut.agent.IMAAgent;
4 | import cn.edu.scut.agent.MABuffer;
5 | import cn.edu.scut.service.*;
6 | import cn.edu.scut.util.DateTimeUtils;
7 | import lombok.extern.slf4j.Slf4j;
8 | import org.springframework.beans.factory.annotation.Autowired;
9 | import org.springframework.beans.factory.annotation.Value;
10 | import org.springframework.context.annotation.Lazy;
11 | import org.springframework.stereotype.Service;
12 |
13 | import java.util.ArrayList;
14 |
15 | @Lazy
16 | @Service
17 | @Slf4j
18 | public class OfflineMARLTrainingRunner implements IRunner {
19 |
20 | @Value("${rl.name}")
21 | private String rlName;
22 |
23 | @Autowired
24 | private IMAAgent agent;
25 |
26 | @Autowired
27 | private TaskService taskService;
28 |
29 | @Autowired
30 | private LinkService linkService;
31 |
32 | @Autowired
33 | private EdgeNodeService edgeNodeService;
34 |
35 | @Autowired
36 | private PlotService plotService;
37 |
38 | @Autowired
39 | private RunnerService runnerService;
40 |
41 | @Autowired
42 | private MABuffer buffer;
43 |
44 | @Value("${rl.test-frequency}")
45 | private int testFrequency;
46 |
47 | @Value("${rl.training-time}")
48 | private int trainingTime;
49 |
50 | @Value("${rl.buffer-path}")
51 | private String bufferPath;
52 |
53 | @Value("${edgeComputing.flag}")
54 | private String flag;
55 |
56 | public void run() {
57 | log.info("========================");
58 | log.info("run rl-offline runner.");
59 | log.info("========================");
60 |
61 | // remove data
62 | linkService.remove(null);
63 | edgeNodeService.remove(null);
64 | taskService.remove(null);
65 | // init edge node and link
66 | runnerService.init();
67 | buffer.loadHdfs(bufferPath);
68 |
69 |
70 | log.info("store communication information start!");
71 | runnerService.run();
72 | taskService.remove(null);
73 | log.info("store communication information end!");
74 |
75 | var dateFlag = DateTimeUtils.getFlag();
76 | var flag = dateFlag + "@" + rlName + "@" + this.flag;
77 |
78 | var episodes = new ArrayList();
79 | var successRates = new ArrayList();
80 | for (int i = 0; i <= trainingTime; i++) {
81 | if (i % testFrequency == 0) {
82 | // update model
83 | agent.saveHdfsModel(flag);
84 | runnerService.updateModelByHdfs(flag);
85 | // performance
86 | double successRate = runnerService.test();
87 | log.info("train time: {}, test success rate: {}", i, successRate);
88 | episodes.add((double) i);
89 | successRates.add(successRate);
90 | }
91 | agent.train();
92 | }
93 | plotService.plot(episodes, successRates, flag);
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/OnlineHeuristicDataRunner.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.runner;
2 |
3 | import cn.edu.scut.agent.MABuffer;
4 | import cn.edu.scut.service.*;
5 | import cn.edu.scut.util.DateTimeUtils;
6 | import lombok.extern.slf4j.Slf4j;
7 | import org.springframework.beans.factory.annotation.Autowired;
8 | import org.springframework.beans.factory.annotation.Value;
9 | import org.springframework.context.annotation.Lazy;
10 | import org.springframework.stereotype.Service;
11 |
12 | import java.util.ArrayList;
13 | import java.util.stream.Collectors;
14 |
15 | @Service
16 | @Slf4j
17 | @Lazy
18 | public class OnlineHeuristicDataRunner implements IRunner {
19 | @Autowired
20 | private TaskService taskService;
21 |
22 | @Autowired
23 | private LinkService linkService;
24 |
25 | @Autowired
26 | private EdgeNodeService edgeNodeService;
27 |
28 | @Autowired
29 | private PlotService plotService;
30 |
31 | @Autowired
32 | private RunnerService runnerService;
33 |
34 | @Value("${edgeComputing.episodeNumber}")
35 | private int episodeNumber;
36 |
37 | @Value("${heuristic.name}")
38 | private String name;
39 |
40 | @Value("${edgeComputing.flag}")
41 | private String flag;
42 |
43 | @Autowired
44 | private MABuffer buffer;
45 |
46 | public void run() {
47 | log.info("=============================");
48 | log.info("run heuristic-data runner");
49 | log.info("=============================");
50 |
51 | linkService.remove(null);
52 | edgeNodeService.remove(null);
53 | taskService.remove(null);
54 | runnerService.init();
55 |
56 | log.info("store communication information start!");
57 | runnerService.run();
58 | taskService.remove(null);
59 | log.info("store communication information end!");
60 |
61 | var dateFlag = DateTimeUtils.getFlag();
62 | var flag = dateFlag + "@" + name + "@" + this.flag;
63 |
64 | var episodes = new ArrayList();
65 | var successRates = new ArrayList();
66 | for (int currentEpisode = 1; currentEpisode <= episodeNumber; currentEpisode++) {
67 | runnerService.run();
68 | double successRate = taskService.getSuccessRate();
69 | log.info("running time: {}, success rate: {}", currentEpisode, successRate);
70 | episodes.add((double) currentEpisode);
71 | successRates.add(successRate);
72 | runnerService.addData();
73 | taskService.remove(null);
74 | }
75 | plotService.plot(episodes, successRates, flag);
76 | buffer.saveHdfs(flag);
77 | var averageSuccessRate = successRates.stream().collect(Collectors.averagingDouble(x -> x));
78 | log.info("average success rate: {}", averageSuccessRate);
79 | }
80 | }
81 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/OnlineHeuristicTestRunner.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.runner;
2 |
3 | import cn.edu.scut.service.*;
4 | import lombok.extern.slf4j.Slf4j;
5 | import org.springframework.beans.factory.annotation.Autowired;
6 | import org.springframework.beans.factory.annotation.Value;
7 | import org.springframework.context.annotation.Lazy;
8 | import org.springframework.stereotype.Service;
9 |
10 | @Service
11 | @Slf4j
12 | @Lazy
13 | public class OnlineHeuristicTestRunner implements IRunner {
14 | @Autowired
15 | private TaskService taskService;
16 |
17 | @Autowired
18 | private LinkService linkService;
19 |
20 | @Autowired
21 | private EdgeNodeService edgeNodeService;
22 |
23 | @Autowired
24 | private PlotService plotService;
25 |
26 | @Autowired
27 | private RunnerService runnerService;
28 |
29 | @Value("${edgeComputing.episodeNumber}")
30 | private int episodeNumber;
31 |
32 | @Value("${heuristic.name}")
33 | private String name;
34 |
35 | @Value("${edgeComputing.flag}")
36 | private String flag;
37 |
38 | @Value("${edgeComputing.taskSeed}")
39 | private int taskSeed;
40 |
41 | public void run() {
42 | log.info(" task seed : {}", taskSeed);
43 | log.info("=============================");
44 | log.info("run heuristic runner.");
45 | log.info("=============================");
46 | linkService.remove(null);
47 | edgeNodeService.remove(null);
48 | taskService.remove(null);
49 | runnerService.init();
50 |
51 | log.info("store communication information start!");
52 | runnerService.run();
53 | taskService.remove(null);
54 | log.info("store communication information end!");
55 |
56 | runnerService.test();
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/OnlineMARLTestingRunner.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.runner;
2 |
3 | import cn.edu.scut.service.EdgeNodeService;
4 | import cn.edu.scut.service.LinkService;
5 | import cn.edu.scut.service.RunnerService;
6 | import cn.edu.scut.service.TaskService;
7 | import lombok.extern.slf4j.Slf4j;
8 | import org.springframework.beans.factory.annotation.Autowired;
9 | import org.springframework.beans.factory.annotation.Value;
10 | import org.springframework.context.annotation.Lazy;
11 | import org.springframework.stereotype.Service;
12 |
13 |
14 | @Service
15 | @Lazy
16 | @Slf4j
17 | public class OnlineMARLTestingRunner implements IRunner {
18 |
19 | @Autowired
20 | private TaskService taskService;
21 |
22 | @Autowired
23 | private LinkService linkService;
24 |
25 | @Autowired
26 | private EdgeNodeService edgeNodeService;
27 |
28 | @Autowired
29 | private RunnerService runnerService;
30 |
31 | @Value("${rl.model-flag}")
32 | private String modelFlag;
33 |
34 | @Override
35 | public void run() {
36 | log.info("========================");
37 | log.info("run rl-test runner!");
38 | log.info("========================");
39 |
40 | // remove data
41 | linkService.remove(null);
42 | edgeNodeService.remove(null);
43 | taskService.remove(null);
44 | // init edge node and link
45 | runnerService.init();
46 | // update model
47 | runnerService.updateModelByHdfs(modelFlag);
48 |
49 | log.info("store communication information start!");
50 | runnerService.run();
51 | taskService.remove(null);
52 | log.info("store communication information end!");
53 |
54 | runnerService.test();
55 | }
56 | }
57 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/runner/OnlineMARLTrainingRunner.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.runner;
2 |
3 | import cn.edu.scut.agent.IMAAgent;
4 | import cn.edu.scut.service.*;
5 | import cn.edu.scut.util.DateTimeUtils;
6 | import lombok.extern.slf4j.Slf4j;
7 | import org.springframework.beans.factory.annotation.Autowired;
8 | import org.springframework.beans.factory.annotation.Value;
9 | import org.springframework.context.annotation.Lazy;
10 | import org.springframework.stereotype.Component;
11 |
12 | import java.util.ArrayList;
13 |
14 | /**
15 | * online
16 | * offline-to-online
17 | */
18 | @Component
19 | @Slf4j
20 | @Lazy
21 | public class OnlineMARLTrainingRunner implements IRunner {
22 |
23 | @Value("${edgeComputing.episodeNumber}")
24 | private int episodeNumber;
25 |
26 | @Value("${rl.name}")
27 | private String rlName;
28 |
29 | @Autowired
30 | private IMAAgent agent;
31 |
32 | @Autowired
33 | private TaskService taskService;
34 |
35 | @Autowired
36 | private LinkService linkService;
37 |
38 | @Autowired
39 | private EdgeNodeService edgeNodeService;
40 |
41 | @Autowired
42 | private PlotService plotService;
43 |
44 | @Autowired
45 | private RunnerService runnerService;
46 |
47 | @Value("${rl.use-trained-model}")
48 | private boolean useTrainedModel;
49 |
50 | @Value("${rl.model-flag}")
51 | private String modelFlag;
52 |
53 | @Value("${edgeComputing.flag}")
54 | private String flag;
55 |
56 | @Value("${edgeComputing.testFrequency}")
57 | private int testFrequency;
58 |
59 | public void run() {
60 | log.info("========================");
61 | log.info("run rl-online runner!");
62 | log.info("========================");
63 | // remove data
64 | linkService.remove(null);
65 | edgeNodeService.remove(null);
66 | taskService.remove(null);
67 | // init edge node and link
68 | runnerService.init();
69 |
70 | // offline to online
71 | if (useTrainedModel) {
72 | agent.loadHdfsModel(modelFlag);
73 | log.info("load model from hdfs.");
74 | }
75 | var dateFlag = DateTimeUtils.getFlag();
76 | var flag = dateFlag + "@" + rlName + "@" + this.flag;
77 |
78 | log.info("store communication information start!");
79 | runnerService.run();
80 | taskService.remove(null);
81 | log.info("store communication information end!");
82 |
83 | var episodes = new ArrayList();
84 | var successRates = new ArrayList();
85 | for (int currentEpisode = 0; currentEpisode <= episodeNumber; currentEpisode++) {
86 | // update model
87 | agent.saveHdfsModel(flag);
88 | runnerService.updateModelByHdfs(flag);
89 | // performance
90 | if (currentEpisode % testFrequency == 0) {
91 | log.info("start to test.");
92 | var successRate = runnerService.test();
93 | episodes.add((double) currentEpisode);
94 | successRates.add(successRate);
95 | log.info("end test.");
96 | }
97 | // run episode
98 | runnerService.run();
99 | var successRate = taskService.getSuccessRate();
100 | log.info("training episode {}, success rate {}", currentEpisode, successRate);
101 | // process data
102 | runnerService.addData();
103 | // training
104 | agent.train();
105 | // remove data
106 | taskService.remove(null);
107 | }
108 | agent.saveHdfsModel(flag);
109 | // plot
110 | plotService.plot(episodes, successRates, flag);
111 | }
112 | }
113 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/service/PlotService.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.service;
2 |
3 | import cn.edu.scut.util.ArrayUtils;
4 | import cn.edu.scut.util.PlotUtils;
5 | import lombok.extern.slf4j.Slf4j;
6 | import org.apache.hadoop.fs.FileSystem;
7 | import org.apache.hadoop.fs.Path;
8 | import org.springframework.beans.factory.annotation.Autowired;
9 | import org.springframework.beans.factory.annotation.Value;
10 | import org.springframework.stereotype.Service;
11 | import tech.tablesaw.plotly.Plot;
12 |
13 | import java.io.IOException;
14 | import java.io.UncheckedIOException;
15 | import java.nio.file.Files;
16 | import java.nio.file.Paths;
17 | import java.util.ArrayList;
18 |
19 | @Service
20 | @Slf4j
21 | public class PlotService {
22 |
23 | @Autowired
24 | private FileSystem fileSystem;
25 |
26 | @Value("${hadoop.hdfs.url}")
27 | private String hdfsUrl;
28 |
29 | public void plot(ArrayList episodes, ArrayList successRates, String flag) {
30 | var episodes_ = ArrayUtils.toDoubleArray(episodes);
31 | var successRates_ = ArrayUtils.toDoubleArray(successRates);
32 | var figure = PlotUtils.plot(new double[][]{episodes_}, new double[][]{successRates_}, new String[]{"rl"}, "episode", "success rate");
33 | var path = Paths.get("results", "figure", flag + ".html");
34 | try {
35 | try {
36 | Files.createDirectories(path.getParent());
37 | } catch (IOException var2) {
38 | throw new UncheckedIOException(var2);
39 | }
40 | var file = path.toFile();
41 | Plot.show(figure, file);
42 | } catch (Exception e) {
43 | log.info("browser not support!");
44 | }
45 | try {
46 | fileSystem.copyFromLocalFile(true, true, new Path(path.toString()), new Path(hdfsUrl + "/" + path));
47 | } catch (IOException e) {
48 | throw new RuntimeException(e);
49 | }
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/service/RunnerService.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.service;
2 |
3 |
4 | import cn.edu.scut.agent.MABuffer;
5 | import cn.edu.scut.bean.Task;
6 | import cn.edu.scut.util.MathUtils;
7 | import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
8 | import lombok.extern.slf4j.Slf4j;
9 | import org.springframework.beans.factory.annotation.Autowired;
10 | import org.springframework.beans.factory.annotation.Value;
11 | import org.springframework.context.annotation.Lazy;
12 | import org.springframework.core.io.FileSystemResource;
13 | import org.springframework.http.HttpEntity;
14 | import org.springframework.http.HttpHeaders;
15 | import org.springframework.http.MediaType;
16 | import org.springframework.http.client.MultipartBodyBuilder;
17 | import org.springframework.stereotype.Service;
18 | import org.springframework.web.client.RestTemplate;
19 |
20 | import java.util.ArrayList;
21 | import java.util.Map;
22 |
23 | @Service
24 | @Slf4j
25 | public class RunnerService {
26 |
27 | @Autowired
28 | private RestTemplate restTemplate;
29 |
30 | @Autowired
31 | private TaskService taskService;
32 |
33 | @Value("${edgeComputing.edgeNodeNumber}")
34 | private int agentNumber;
35 |
36 | @Value("${edgeComputing.episodeLimit}")
37 | private int episodeLimit;
38 |
39 | @Value("${edgeComputing.timeSlot}")
40 | private int timeSlotLen;
41 |
42 | // heuristic-test
43 | @Value("${rl.state-shape:0}")
44 | private int stateShape;
45 |
46 | @Value("${edgeComputing.testNumber}")
47 | private int testNumber;
48 |
49 | // heuristic-test
50 | @Lazy
51 | @Autowired(required = false)
52 | private MABuffer buffer;
53 |
54 | @Autowired
55 | private TransitionService transitionService;
56 |
57 | public void init() {
58 | var controllerUrl = "http://edge-controller";
59 | var generateMessage = restTemplate.getForObject(controllerUrl + "/generate", String.class);
60 | log.info("generate edge nodes configuration: {}", generateMessage);
61 | var initMessage = restTemplate.getForObject(controllerUrl + "/init", String.class);
62 | log.info("init edge nodes: {}", initMessage);
63 | }
64 |
65 | public void run() {
66 | var controllerUrl = "http://edge-controller";
67 | var restartMessage = restTemplate.getForObject(controllerUrl + "/restart", String.class);
68 | log.info("restart experiment : {}", restartMessage);
69 | while (true) {
70 | try {
71 | Thread.sleep(1000);
72 | } catch (InterruptedException e) {
73 | log.error("{}", e.getMessage());
74 | }
75 | var count = taskService.count(new QueryWrapper().ne("status", "NEW"));
76 | if (count >= (episodeLimit + 1) * agentNumber) {
77 | break;
78 | }
79 | }
80 | var stopMessage = restTemplate.getForObject(controllerUrl + "/stop", String.class);
81 | log.info("stop experiment: {}", stopMessage);
82 | }
83 |
84 | public double test() {
85 | var list = new ArrayList();
86 | for (int i = 1; i <= testNumber; i++) {
87 | run();
88 | var successRate = taskService.getSuccessRate();
89 | log.info("test {}, success rate : {}", i, successRate);
90 | list.add(successRate);
91 | taskService.remove(null);
92 | }
93 | var avg = MathUtils.avg(list);
94 | var std = MathUtils.std(list);
95 | log.info("avg success rate: {}", avg);
96 | log.info("std success rate: {}", std);
97 | return avg;
98 | }
99 |
100 | public void updateModelByHdfs(String flag) {
101 | for (int i = 1; i < agentNumber; i++) {
102 | String edgeNodeId = String.format("edge-node-%d", i);
103 | restTemplate.getForObject("http://" + edgeNodeId + "/updateParam/{flag}", String.class, Map.of("flag", flag));
104 | }
105 | }
106 |
107 | public void updateModelByFileTransfer(String flag) {
108 | for (int i = 1; i < agentNumber; i++) {
109 | String edgeNodeId = String.format("edge-node-%d", i);
110 | var headers = new HttpHeaders();
111 | headers.setContentType(MediaType.MULTIPART_FORM_DATA);
112 | var multipartBodyBuilder = new MultipartBodyBuilder();
113 | var file1 = new FileSystemResource("results/model/" + flag + "/actor.param");
114 | multipartBodyBuilder.part("actor", file1, MediaType.TEXT_PLAIN);
115 | var multipartBody = multipartBodyBuilder.build();
116 | var httpEntity = new HttpEntity<>(multipartBody, headers);
117 | try {
118 | restTemplate.postForObject("http://" + edgeNodeId + "/updateModel", httpEntity, String.class);
119 | } catch (Exception e) {
120 | log.error(e.getMessage());
121 | throw new RuntimeException(e);
122 | }
123 | }
124 | }
125 |
126 | public void addData() {
127 | var teamRewards = new float[episodeLimit * 2];
128 | var states = new float[episodeLimit][agentNumber][stateShape];
129 | var actions = new int[episodeLimit][agentNumber][1];
130 | var availActions = new int[episodeLimit][agentNumber][1];
131 | var rewards = new float[episodeLimit][agentNumber][1];
132 | var nextStates = new float[episodeLimit][agentNumber][stateShape];
133 |
134 | for (int i = 1; i <= episodeLimit; i++) {
135 | for (int j = 1; j <= agentNumber; j++) {
136 | var source = String.format("edge-node-%d", j);
137 | var task = taskService.getOne(new QueryWrapper().eq("source", source).eq("time_slot", i));
138 | var nextTask = taskService.getOne(new QueryWrapper().eq("source", source).eq("time_slot", i + 1));
139 | var state = transitionService.getState(task.getId());
140 | states[i - 1][j - 1] = state;
141 | var action = transitionService.getAction(task.getId());
142 | actions[i - 1][j - 1] = new int[]{action};
143 | var nextState = transitionService.getState(nextTask.getId());
144 | nextStates[i - 1][j - 1] = nextState;
145 | var availAction = transitionService.getAvailAction(task.getId());
146 | availActions[i - 1][j - 1] = availAction;
147 | var reward = transitionService.getReward(task.getId());
148 | long totalTime = task.getTransmissionWaitingTime() + task.getTransmissionTime() + task.getExecutionWaitingTime() + task.getExecutionTime();
149 | var endTimeSlot = task.getTimeSlot() + totalTime / timeSlotLen;
150 | int timeSlotIndex = (int) endTimeSlot - 1;
151 | teamRewards[timeSlotIndex] += reward;
152 | }
153 | }
154 | for (int i = 1; i <= episodeLimit; i++) {
155 | for (int j = 1; j <= agentNumber; j++) {
156 | rewards[i - 1][j - 1] = new float[]{teamRewards[i - 1]};
157 | }
158 | }
159 | for (int i = 1; i <= episodeLimit; i++) {
160 | buffer.insert(states[i - 1], actions[i - 1], availActions[i - 1], rewards[i - 1], nextStates[i - 1]);
161 | }
162 | }
163 | }
164 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/util/DateTimeUtils.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.util;
2 |
import java.time.Clock;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
5 |
public class DateTimeUtils {

    /**
     * Offset used for flags: UTC+8. The original code added 8 hours to the system-local time,
     * which matches this only when the host runs in UTC (presumably the intended
     * China Standard Time).
     */
    private static final ZoneOffset FLAG_OFFSET = ZoneOffset.ofHours(8);

    /**
     * 24-hour pattern. The former "hh" (12-hour clock without an AM/PM marker) gave identical
     * flags for runs started 12 hours apart on the same day.
     */
    private static final DateTimeFormatter FLAG_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd-HH_mm_ss");

    /**
     * Returns the current UTC+8 time formatted as {@code yyyy-MM-dd-HH_mm_ss}, used as a run flag.
     *
     * @return the current run flag
     */
    public static String getFlag() {
        return getFlag(Clock.systemUTC());
    }

    /**
     * Formats the instant of {@code clock}, converted to UTC+8, as {@code yyyy-MM-dd-HH_mm_ss}.
     *
     * @param clock source of the current instant (its zone is ignored)
     * @return the run flag for that instant
     */
    public static String getFlag(Clock clock) {
        return FLAG_FORMATTER.format(LocalDateTime.now(clock.withZone(FLAG_OFFSET)));
    }
}
15 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/util/MathUtils.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.util;
2 |
3 | import java.util.ArrayList;
4 |
public class MathUtils {

    /**
     * Arithmetic mean.
     *
     * @param data values (must not contain null)
     * @return the mean, or {@code NaN} when {@code data} is empty
     */
    public static double avg(ArrayList<Double> data) {
        if (data.isEmpty()) {
            return Double.NaN;
        }
        double sum = 0;
        for (double value : data) {
            sum += value;
        }
        return sum / data.size();
    }

    /**
     * Sample standard deviation (Bessel-corrected, denominator {@code n - 1}).
     *
     * @param data values (must not contain null)
     * @return the sample standard deviation, or {@code NaN} when fewer than two values are given
     *     (previously an empty list yielded {@code -0.0})
     */
    public static double std(ArrayList<Double> data) {
        int n = data.size();
        if (n < 2) {
            return Double.NaN;
        }
        double mean = avg(data);
        double sumOfSquares = 0;
        for (double value : data) {
            double deviation = value - mean;
            sumOfSquares += deviation * deviation;
        }
        return Math.sqrt(sumOfSquares / (n - 1));
    }
}
25 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/java/cn/edu/scut/util/SpringBeanUtils.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.util;
2 |
3 | import org.springframework.context.ApplicationContext;
4 |
5 | public class SpringBeanUtils {
6 | public static ApplicationContext applicationContext;
7 |
8 | public static void setApplicationContext(ApplicationContext applicationContext) {
9 | SpringBeanUtils.applicationContext = applicationContext;
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/edge-experiment/src/main/resources/bootstrap.yaml:
--------------------------------------------------------------------------------
1 | spring:
2 | application:
3 | name: edge-experiment
4 | profiles:
5 | # active: edge-computing,mappo
6 | # active: edge-computing,masac
7 | # active: edge-computing,reliability-two-choice
8 | # active: edge-computing,reactive
9 | active: edge-computing,random
10 | # active: edge-computing,local
11 | cloud:
12 | nacos:
13 | discovery:
14 | server-addr: 222.201.187.50:30848
15 | watch-delay: 3000
16 | config:
17 | server-addr: 222.201.187.50:30848
18 | prefix: application
19 | file-extension: yaml
20 | server:
21 | port: 9001
--------------------------------------------------------------------------------
/edge-experiment/test-results/buffer/test-buffer.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lhc0512/edge-computing/d13e6f16266b5b98d69c8e540f61ac60f6db6e31/edge-experiment/test-results/buffer/test-buffer.txt
--------------------------------------------------------------------------------
/edge-node/Dockerfile:
--------------------------------------------------------------------------------
FROM lhc-ubuntu-20:v1.1
LABEL maintainer=hongcai
# Put the bundled JDK 17 first on PATH (key=value form; the space-separated form is legacy).
ENV PATH=/root/lhc_dev/jdk-17.0.5/bin:${PATH}
COPY target/edge-node-1.0-SNAPSHOT.jar /root/app.jar
WORKDIR /root
# A shell is needed to expand $spring_application_name at start-up. `exec` replaces that shell
# with the JVM, so java runs as PID 1 and receives SIGTERM from `docker stop` directly.
ENTRYPOINT ["sh","-c","exec java -jar app.jar --spring_application_name=$spring_application_name"]
--------------------------------------------------------------------------------
/edge-node/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | edge-computing
7 | cn.edu.scut
8 | 1.0-SNAPSHOT
9 |
10 | 4.0.0
11 |
12 | edge-node
13 |
14 |
15 | 17
16 | 17
17 |
18 |
19 |
20 |
21 | mysql
22 | mysql-connector-java
23 |
24 |
25 | com.alibaba
26 | druid-spring-boot-starter
27 |
28 |
29 | com.baomidou
30 | mybatis-plus-boot-starter
31 |
32 |
33 |
34 | com.alibaba.cloud
35 | spring-cloud-starter-alibaba-nacos-discovery
36 |
37 |
38 | com.alibaba.cloud
39 | spring-cloud-starter-alibaba-nacos-config
40 |
41 |
42 | org.springframework.cloud
43 | spring-cloud-starter-bootstrap
44 |
45 |
46 |
47 |
48 | org.springframework.boot
49 | spring-boot-starter-web
50 |
51 |
52 |
53 | org.springframework.boot
54 | spring-boot-starter-logging
55 |
56 |
57 |
58 |
59 | org.projectlombok
60 | lombok
61 | true
62 |
63 |
64 |
65 |
66 | org.springframework.boot
67 | spring-boot-starter-test
68 | test
69 |
70 |
71 |
72 |
73 | cn.edu.scut
74 | edge-api
75 | 1.0-SNAPSHOT
76 |
77 |
78 | org.springframework.cloud
79 | spring-cloud-loadbalancer
80 |
81 |
82 | org.apache.hadoop
83 | hadoop-common
84 |
85 |
86 | org.slf4j
87 | slf4j-api
88 |
89 |
90 | org.slf4j
91 | slf4j-reload4j
92 |
93 |
94 |
95 |
96 |
97 | org.apache.hadoop
98 | hadoop-hdfs
99 |
100 |
101 |
102 | org.apache.hadoop
103 | hadoop-hdfs-client
104 |
105 |
106 |
107 | cn.edu.scut
108 | edge-algorithm
109 | 1.0-SNAPSHOT
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 | org.springframework.boot
119 | spring-boot-maven-plugin
120 |
121 | true
122 | true
123 |
124 |
125 |
126 |
127 | repackage
128 |
129 |
130 |
131 |
132 |
133 |
134 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/EdgeNodeApp.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut;
2 |
3 | import cn.edu.scut.utils.SpringBeanUtils;
4 | import org.springframework.boot.SpringApplication;
5 | import org.springframework.boot.autoconfigure.SpringBootApplication;
6 | import org.springframework.cloud.client.discovery.EnableDiscoveryClient;
7 | import org.springframework.context.ConfigurableApplicationContext;
8 | import org.springframework.scheduling.annotation.EnableAsync;
9 |
10 | @SpringBootApplication
11 | @EnableDiscoveryClient
12 | @EnableAsync
13 | public class EdgeNodeApp {
14 | public static void main(String[] args) {
15 | ConfigurableApplicationContext context = SpringApplication.run(EdgeNodeApp.class, args);
16 | SpringBeanUtils.setApplicationContext(context);
17 | }
18 | }
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/bean/EdgeNodeSystem.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.bean;
2 |
3 | import cn.edu.scut.queue.ExecutionQueue;
4 | import cn.edu.scut.queue.TransmissionQueue;
5 | import lombok.Getter;
6 | import lombok.Setter;
7 | import org.springframework.stereotype.Component;
8 |
9 | import java.util.Map;
10 |
11 | @Getter // Lombok: generates a getter for every field below
12 | @Setter // Lombok: generates a setter for every field below
13 | @Component // Spring bean (default singleton scope); presumably holds this edge node's runtime state — confirm
14 | public class EdgeNodeSystem {
15 | private EdgeNode edgeNode; // this node's EdgeNode record (bean type from edge-api)
16 | private Map linkMap; // NOTE(review): type arguments not visible in this dump; presumably the Link to each peer node — confirm
17 | private ExecutionQueue executionQueue; // the node's execution queue (cn.edu.scut.queue.ExecutionQueue)
18 | private Map transmissionQueueMap; // NOTE(review): presumably one TransmissionQueue per destination node — confirm key/value types
19 | // threshold on the execution queue; its unit and how it is compared are defined by the code using it — TODO confirm
20 | private float executionQueueThreshold;
21 | }
22 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/config/Config.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.config;
2 |
3 | import lombok.extern.slf4j.Slf4j;
4 | import org.springframework.beans.factory.annotation.Value;
5 | import org.springframework.cloud.client.loadbalancer.LoadBalanced;
6 | import org.springframework.context.annotation.Bean;
7 | import org.springframework.context.annotation.Configuration;
8 | import org.springframework.web.client.RestTemplate;
9 |
10 | import java.util.Random;
11 |
12 | @Configuration
13 | @Slf4j
14 | public class Config {
15 |
16 | @Value("${edgeComputing.reliabilitySeed}")
17 | Integer reliabilitySeed;
18 |
19 | @Value("${edgeComputing.schedulerSeed}")
20 | Integer schedulerSeed;
21 |
22 | @Value("${spring.application.name}")
23 | String name;
24 |
25 | @Bean
26 | public Random reliabilityRandom() {
27 | Integer id = Integer.parseInt(name.split("-")[2]);
28 | return new Random(reliabilitySeed + id);
29 | }
30 |
31 | @Bean
32 | public Random schedulerRandom() {
33 | Integer id = Integer.parseInt(name.split("-")[2]);
34 | return new Random(schedulerSeed + id);
35 | }
36 |
37 | @Bean
38 | @LoadBalanced
39 | RestTemplate restTemplate() {
40 | return new RestTemplate();
41 | }
42 | }
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/config/HadoopConfig.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.config;
2 |
3 | import org.apache.hadoop.fs.FileSystem;
4 | import org.springframework.beans.factory.annotation.Value;
5 | import org.springframework.context.annotation.Bean;
6 | import org.springframework.context.annotation.Configuration;
7 |
8 | import java.io.IOException;
9 | import java.net.URI;
10 | import java.net.URISyntaxException;
11 |
12 | @Configuration
13 | public class HadoopConfig {
14 |
15 | @Value("${hadoop.hdfs.url}")
16 | private String hdfsUrl;
17 |
18 | @Bean
19 | public FileSystem fileSystem() {
20 | URI uri;
21 | try {
22 | uri = new URI(hdfsUrl);
23 | } catch (URISyntaxException e) {
24 | throw new RuntimeException(e);
25 | }
26 | var configuration = new org.apache.hadoop.conf.Configuration();
27 | String user = "hongcai";
28 | FileSystem fileSystem;
29 | try {
30 | fileSystem = FileSystem.get(uri, configuration, user);
31 | } catch (IOException | InterruptedException e) {
32 | throw new RuntimeException(e);
33 | }
34 | return fileSystem;
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/controller/CmdController.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.controller;
2 |
3 | import cn.edu.scut.service.EdgeNodeSystemService;
4 | import org.springframework.web.bind.annotation.GetMapping;
5 | import org.springframework.web.bind.annotation.RequestMapping;
6 | import org.springframework.web.bind.annotation.RequestMethod;
7 | import org.springframework.web.bind.annotation.RestController;
8 |
9 | import javax.annotation.Resource;
10 |
11 | @RestController
12 | @RequestMapping(value = "/cmd", method = {RequestMethod.GET, RequestMethod.POST})
13 | public class CmdController {
14 |
15 | @Resource
16 | EdgeNodeSystemService edgeNodeSystemService;
17 |
18 | @GetMapping("/init")
19 | public String init(){
20 | edgeNodeSystemService.init();
21 | return "success";
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/controller/EdgeNodeController.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.controller;
2 |
3 | import cn.edu.scut.bean.EdgeNode;
4 | import cn.edu.scut.bean.RatcVo;
5 | import cn.edu.scut.bean.Task;
6 | import cn.edu.scut.service.EdgeNodeSystemService;
7 | import cn.edu.scut.thread.ExecutionRunnable;
8 | import lombok.extern.apachecommons.CommonsLog;
9 | import org.springframework.web.bind.annotation.*;
10 |
11 | import javax.annotation.Resource;
12 |
13 | @RestController
14 | @RequestMapping(value = "/edgeNode", method = {RequestMethod.GET, RequestMethod.POST})
15 | @CommonsLog
16 | public class EdgeNodeController {
17 |
18 | @Resource
19 | private EdgeNodeSystemService edgeNodeSystemService;
20 |
21 | @PostMapping("/task")
22 | public String receiveEdgeNodeTask(@RequestBody Task task) {
23 | log.info("receive task from edge node: " + task.getSource());
24 | edgeNodeSystemService.processEdgeNodeTask(task);
25 | return "success";
26 | }
27 |
28 | @GetMapping("/queue")
29 | public Integer queue() {
30 | return edgeNodeSystemService.getEdgeNodeSystem().getExecutionQueue().getSize();
31 | }
32 |
33 | @GetMapping("/info")
34 | public EdgeNode info() {
35 | return edgeNodeSystemService.getEdgeNodeSystem().getEdgeNode();
36 | }
37 |
38 | @GetMapping("/waitingTime")
39 | public Long waitingTime() {
40 | var queue = edgeNodeSystemService.getEdgeNodeSystem().getExecutionQueue().getExecutor().getQueue();
41 | long waitingTime = 0;
42 | for (Runnable runnable : queue) {
43 | var r = (ExecutionRunnable) runnable;
44 | waitingTime += r.getTask().getExecutionTime();
45 | }
46 | return waitingTime;
47 | }
48 |
49 | @GetMapping("/ratc")
50 | public RatcVo ratc() {
51 | var queue = edgeNodeSystemService.getEdgeNodeSystem().getExecutionQueue().getExecutor().getQueue();
52 | long waitingTime = 0;
53 | for (Runnable runnable : queue) {
54 | var r = (ExecutionRunnable) runnable;
55 | waitingTime += r.getTask().getExecutionTime();
56 | }
57 | var res = new RatcVo();
58 | res.setWaitingTime(waitingTime);
59 | var edgeNode = edgeNodeSystemService.getEdgeNodeSystem().getEdgeNode();
60 | res.setExecutionFailureRate(edgeNode.getExecutionFailureRate());
61 | res.setCapacity(edgeNode.getCapacity());
62 | res.setEdgeId(edgeNode.getName());
63 | return res;
64 | }
65 |
66 | @GetMapping("/available")
67 | public Integer avail() {
68 | int queueSize = edgeNodeSystemService.getEdgeNodeSystem().getExecutionQueue().getSize();
69 | if (queueSize < edgeNodeSystemService.getEdgeNodeSystem().getExecutionQueueThreshold()) {
70 | return 1;
71 | } else {
72 | return 0;
73 | }
74 | }
75 | }
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/controller/FileController.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.controller;
2 |
3 | import lombok.extern.slf4j.Slf4j;
4 | import org.springframework.http.MediaType;
5 | import org.springframework.util.StreamUtils;
6 | import org.springframework.web.bind.annotation.PostMapping;
7 | import org.springframework.web.bind.annotation.RestController;
8 |
9 | import javax.servlet.ServletException;
10 | import javax.servlet.http.HttpServletRequest;
11 | import javax.servlet.http.Part;
12 | import java.io.IOException;
13 | import java.io.InputStream;
14 | import java.nio.file.Files;
15 | import java.nio.file.Paths;
16 |
17 | @RestController
18 | @Slf4j
19 | public class FileController {
20 |
21 | @PostMapping(value = "/upload", consumes = MediaType.MULTIPART_FORM_DATA_VALUE, produces = MediaType.TEXT_PLAIN_VALUE)
22 | public String upload(HttpServletRequest request) throws IOException, ServletException {
23 | for (Part part : request.getParts()) {
24 | log.info("content type: {}", part.getContentType()); //text/plain
25 | log.info("parameter name: {}", part.getName()); // actor
26 | log.info("file name: {}", part.getSubmittedFileName()); // actor.param
27 | log.info("file size: {}", part.getSize());
28 | log.info("save file");
29 | InputStream inputStream = part.getInputStream();
30 | var path = Paths.get("results/model/edge-node-1/"+ part.getSubmittedFileName());
31 | Files.createDirectories(path.getParent());
32 | Files.write(path, StreamUtils.copyToByteArray(inputStream));
33 | }
34 | return "success";
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/controller/ModelController.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.controller;
2 |
3 | import cn.edu.scut.agent.IMAAgent;
4 | import lombok.extern.slf4j.Slf4j;
5 | import org.springframework.beans.factory.annotation.Autowired;
6 | import org.springframework.web.bind.annotation.GetMapping;
7 | import org.springframework.web.bind.annotation.PathVariable;
8 | import org.springframework.web.bind.annotation.PostMapping;
9 | import org.springframework.web.bind.annotation.RestController;
10 |
11 | import javax.servlet.ServletException;
12 | import javax.servlet.http.HttpServletRequest;
13 | import javax.servlet.http.Part;
14 | import java.io.IOException;
15 |
16 | @RestController
17 | @Slf4j
18 | public class ModelController {
19 |
20 | // heuristic
21 | @Autowired(required = false)
22 | private IMAAgent agent;
23 |
24 | @GetMapping("/updateParam/{flag}")
25 | public String updateParam(@PathVariable("flag") String flag) {
26 | agent.loadHdfsModel(flag);
27 | return "success";
28 | }
29 |
30 | @PostMapping("/updateModel")
31 | public String updateStreamParam(HttpServletRequest request) throws IOException, ServletException {
32 | for (Part part : request.getParts()) {
33 | log.info("content type: {}", part.getContentType()); //text/plain
34 | log.info("parameter name: {}", part.getName()); // actor
35 | log.info("file name: {}", part.getSubmittedFileName()); // ctor.param
36 | log.info("file size: {}", part.getSize());
37 | agent.loadSteamModel(part.getInputStream(), part.getSubmittedFileName());
38 | log.info("update model completed!");
39 | }
40 | return "success";
41 | }
42 |
43 | @GetMapping("/loadModel/{flag}")
44 | public String loadLocalModel(@PathVariable("flag") String flag) {
45 | agent.loadModel(flag);
46 | return "success";
47 | }
48 | }
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/controller/UserController.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.controller;
2 |
3 | import cn.edu.scut.bean.Task;
4 | import cn.edu.scut.service.EdgeNodeSystemService;
5 | import lombok.extern.apachecommons.CommonsLog;
6 | import org.springframework.web.bind.annotation.*;
7 |
8 | import javax.annotation.Resource;
9 | import java.time.LocalDateTime;
10 |
11 | @RestController
12 | @RequestMapping(value = "/user", method = {RequestMethod.GET, RequestMethod.POST})
13 | @CommonsLog
14 | public class UserController {
15 | @Resource
16 | EdgeNodeSystemService edgeNodeSystemService;
17 |
18 | @PostMapping("/task")
19 | public String receiveUserTask(@RequestBody Task task) {
20 | log.info("receive task from user");
21 | task.setArrivalTime(LocalDateTime.now());
22 | edgeNodeSystemService.processUserTask(task);
23 | return "success";
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/queue/ExecutionQueue.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.queue;
2 |
3 | import cn.edu.scut.bean.Task;
4 | import cn.edu.scut.thread.ExecutionRunnable;
5 | import cn.edu.scut.utils.SpringBeanUtils;
6 | import lombok.Getter;
7 |
8 | import java.util.concurrent.LinkedBlockingQueue;
9 | import java.util.concurrent.ThreadPoolExecutor;
10 | import java.util.concurrent.TimeUnit;
11 |
12 | public class ExecutionQueue {
13 | // execution queue based on the queue in thread pool
14 | @Getter
15 | private ThreadPoolExecutor executor = new ThreadPoolExecutor(1, 1, 0, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
16 |
17 | public void add(Task task) {
18 | ExecutionRunnable runnable = SpringBeanUtils.applicationContext.getBean(ExecutionRunnable.class);
19 | runnable.setTask(task);
20 | executor.execute(runnable);
21 | }
22 | public int getSize(){
23 | return executor.getQueue().size();
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/queue/TransmissionQueue.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.queue;
2 |
3 | import cn.edu.scut.bean.Task;
4 | import cn.edu.scut.thread.TransmissionRunnable;
5 | import cn.edu.scut.utils.SpringBeanUtils;
6 |
7 | import java.util.concurrent.LinkedBlockingQueue;
8 | import java.util.concurrent.ThreadPoolExecutor;
9 | import java.util.concurrent.TimeUnit;
10 |
11 |
12 | public class TransmissionQueue {
13 | // transmission queue based on the queue in thread pool
14 | ThreadPoolExecutor thread = new ThreadPoolExecutor(1, 1, 0, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
15 |
16 | public void add(Task task) {
17 | TransmissionRunnable runnable = SpringBeanUtils.applicationContext.getBean(TransmissionRunnable.class);
18 | runnable.setTask(task);
19 | thread.execute(runnable);
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/scheduler/DRLScheduler.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.scheduler;
2 |
3 |
4 | import cn.edu.scut.agent.IMAAgent;
5 | import cn.edu.scut.bean.Task;
6 | import cn.edu.scut.service.TaskService;
7 | import cn.edu.scut.service.TransitionService;
8 | import cn.edu.scut.util.ArrayUtils;
9 | import lombok.extern.slf4j.Slf4j;
10 | import org.springframework.beans.factory.annotation.Autowired;
11 | import org.springframework.beans.factory.annotation.Value;
12 | import org.springframework.cloud.context.config.annotation.RefreshScope;
13 | import org.springframework.stereotype.Service;
14 |
15 | @Service
16 | @Slf4j
17 | @RefreshScope
18 | public class DRLScheduler implements IScheduler {
19 |
20 | // heuristic method
21 | @Autowired(required = false)
22 | private IMAAgent agent;
23 |
24 | @Autowired
25 | private TransitionService transitionService;
26 |
27 | @Autowired
28 | private TaskService taskService;
29 |
30 | @Value("${edgeComputing.edgeNodeNumber}")
31 | private int agentNumber;
32 |
33 | @Value("${spring.application.name}")
34 | public String name;
35 |
36 | @Override
37 | public String selectAction(Long taskId) {
38 | Task task = taskService.getById(taskId);
39 | var availAction = new int[agentNumber + 1];
40 | availAction[agentNumber] = 0;
41 | for (int i = 1; i <= agentNumber; i++) {
42 | availAction[i - 1] = 1;
43 | }
44 | String s = ArrayUtils.arrayToString(availAction);
45 | task.setAvailAction(s);
46 | taskService.updateById(task);
47 | var state = transitionService.getState(taskId);
48 | int i = agent.selectAction(state, availAction, false) + 1;
49 | log.info("drl select action {}", i);
50 | return String.format("edge-node-%d", i);
51 | }
52 | }
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/scheduler/IScheduler.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.scheduler;
2 |
/**
 * Strategy deciding which edge node processes a task.
 */
public interface IScheduler {
    /**
     * Chooses the destination of a task.
     *
     * @param taskId id of the task being scheduled
     * @return service name of the chosen edge node, e.g. {@code "edge-node-3"}
     */
    String selectAction(Long taskId);
}
6 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/scheduler/RandomScheduler.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.scheduler;
2 |
3 | import lombok.Setter;
4 | import org.springframework.beans.factory.annotation.Autowired;
5 | import org.springframework.beans.factory.annotation.Value;
6 | import org.springframework.context.annotation.Lazy;
7 | import org.springframework.stereotype.Component;
8 |
9 | import java.util.ArrayList;
10 | import java.util.Random;
11 |
12 | @Lazy
13 | @Component
14 | @Setter
15 | public class RandomScheduler implements IScheduler {
16 | @Autowired
17 | private Random schedulerRandom;
18 |
19 | @Value("${edgeComputing.edgeNodeNumber}")
20 | private int edgeNodeNumber;
21 |
22 | @Override
23 | public String selectAction(Long taskId) {
24 | var services = new ArrayList();
25 | for (int i = 1; i <= edgeNodeNumber; i++) {
26 | services.add(String.format("edge-node-%d", i));
27 | }
28 | int index = schedulerRandom.nextInt(services.size());
29 | return services.get(index);
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/scheduler/ReactiveScheduler.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.scheduler;
2 |
3 | import cn.edu.scut.bean.EdgeNodeSystem;
4 | import org.springframework.beans.factory.annotation.Autowired;
5 | import org.springframework.beans.factory.annotation.Value;
6 | import org.springframework.context.annotation.Lazy;
7 | import org.springframework.stereotype.Component;
8 | import org.springframework.web.client.RestTemplate;
9 |
10 | import java.util.*;
11 |
12 | @Lazy
13 | @Component
14 | public class ReactiveScheduler implements IScheduler {
15 |
16 | @Value("${spring.application.name}")
17 | public String name;
18 |
19 | @Autowired
20 | private EdgeNodeSystem edgeNodeSystem;
21 |
22 | @Autowired
23 | private RestTemplate restTemplate;
24 |
25 | @Value("${edgeComputing.edgeNodeNumber}")
26 | private int edgeNodeNumber;
27 |
28 | @Value("${heuristic.queue-coef}")
29 | private float queueCoef;
30 |
31 | @Autowired
32 | private Random schedulerRandom;
33 |
34 | @Override
35 | public String selectAction(Long taskId) {
36 | if (edgeNodeSystem.getExecutionQueue().getSize() < edgeNodeSystem.getExecutionQueueThreshold() * queueCoef) {
37 | return name;
38 | }
39 | var edgeNodeIds = new ArrayList();
40 | for (int i = 1; i <= edgeNodeNumber; i++) {
41 | edgeNodeIds.add(String.format("edge-node-%d", i));
42 | }
43 | var selectedNodes = new HashSet();
44 | // while true
45 | for (int i = 0; i < 1000; i++) {
46 | var edgeNodeId = edgeNodeIds.get(schedulerRandom.nextInt(edgeNodeIds.size()));
47 | if (edgeNodeIds.size() > 2 && edgeNodeId.equals(name)) {
48 | continue;
49 | }
50 | selectedNodes.add(edgeNodeId);
51 | if (selectedNodes.size() == 2) {
52 | break;
53 | }
54 | }
55 | var queue = new PriorityQueue<>(Comparator.comparingInt((Map o) -> (int) o.get("queue")));
56 | for (String edgeNodeId : selectedNodes) {
57 | var info = new HashMap();
58 | var url = String.format("http://%s/edgeNode/queue", edgeNodeId);
59 | var queueSize = restTemplate.getForObject(url, Integer.class);
60 | info.put("queue", queueSize);
61 | info.put("edgeId", edgeNodeId);
62 | queue.add(info);
63 | }
64 | return (String) queue.poll().get("edgeId");
65 | }
66 | }
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/scheduler/ReliabilityTwoChoice.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.scheduler;
2 |
3 | import cn.edu.scut.bean.RatcVo;
4 | import cn.edu.scut.service.TaskService;
5 | import lombok.extern.slf4j.Slf4j;
6 | import org.springframework.beans.factory.annotation.Autowired;
7 | import org.springframework.beans.factory.annotation.Value;
8 | import org.springframework.context.annotation.Lazy;
9 | import org.springframework.stereotype.Component;
10 | import org.springframework.web.client.RestTemplate;
11 |
12 | import java.util.ArrayList;
13 | import java.util.HashSet;
14 | import java.util.Random;
15 |
16 | @Slf4j
17 | @Lazy
18 | @Component
19 | public class ReliabilityTwoChoice implements IScheduler {
20 |
21 | @Autowired
22 | private Random schedulerRandom;
23 |
24 | @Autowired
25 | private RestTemplate restTemplate;
26 |
27 | @Autowired
28 | private TaskService taskService;
29 |
30 | @Value("${edgeComputing.edgeNodeNumber}")
31 | private int edgeNodeNumber;
32 |
33 | @Override
34 | public String selectAction(Long taskId) {
35 | var task = taskService.getById(taskId);
36 |
37 | var edgeNodeIds = new ArrayList();
38 | for (int i = 1; i <= edgeNodeNumber; i++) {
39 | edgeNodeIds.add(String.format("edge-node-%d", i));
40 | }
41 | var selectedNodes = new HashSet();
42 | // while true
43 | for (int i = 0; i < 1000; i++) {
44 | selectedNodes.add(edgeNodeIds.get(schedulerRandom.nextInt(edgeNodeIds.size())));
45 | if (selectedNodes.size() == 2) {
46 | break;
47 | }
48 | }
49 |
50 | var selectEdgeNodeInfo = new ArrayList();
51 | for (String edgeNodeId : selectedNodes) {
52 | var url1 = String.format("http://%s/edgeNode/ratc", edgeNodeId);
53 | var ratcVo = restTemplate.getForObject(url1, RatcVo.class);
54 | selectEdgeNodeInfo.add(ratcVo);
55 | }
56 |
57 | boolean flag = true;
58 | for (RatcVo ratcVo : selectEdgeNodeInfo) {
59 | long executionTime = (long) (task.getCpuCycle().doubleValue() / ratcVo.getCapacity().doubleValue() * 1000);
60 | ratcVo.setTotalTime(executionTime + ratcVo.getWaitingTime());
61 | if (ratcVo.getTotalTime() > task.getDeadline()) {
62 | flag = false;
63 | }
64 | }
65 | if (flag) {
66 | var failureRate1 = selectEdgeNodeInfo.get(0).getExecutionFailureRate();
67 | var failureRate2 = selectEdgeNodeInfo.get(1).getExecutionFailureRate();
68 | if (failureRate1 < failureRate2) {
69 | return selectEdgeNodeInfo.get(0).getEdgeId();
70 | } else {
71 | return selectEdgeNodeInfo.get(1).getEdgeId();
72 | }
73 | } else {
74 | var time1 = selectEdgeNodeInfo.get(0).getTotalTime();
75 | var time2 = selectEdgeNodeInfo.get(1).getTotalTime();
76 | if (time1 < time2) {
77 | return selectEdgeNodeInfo.get(0).getEdgeId();
78 | } else {
79 | return selectEdgeNodeInfo.get(1).getEdgeId();
80 | }
81 | }
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/service/EdgeNodeSystemService.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.service;
2 |
3 | import cn.edu.scut.bean.*;
4 | import cn.edu.scut.queue.ExecutionQueue;
5 | import cn.edu.scut.queue.TransmissionQueue;
6 | import cn.edu.scut.scheduler.*;
7 | import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
8 | import lombok.Getter;
9 | import lombok.Setter;
10 | import lombok.extern.slf4j.Slf4j;
11 | import org.springframework.beans.factory.annotation.Autowired;
12 | import org.springframework.beans.factory.annotation.Value;
13 | import org.springframework.cloud.context.config.annotation.RefreshScope;
14 | import org.springframework.context.annotation.Lazy;
15 | import org.springframework.scheduling.annotation.Async;
16 | import org.springframework.stereotype.Service;
17 | import org.springframework.web.client.RestTemplate;
18 |
19 | import java.time.LocalDateTime;
20 | import java.util.HashMap;
21 |
22 | @Setter
23 | @Getter
24 | @Service
25 | @Slf4j
26 | @RefreshScope
27 | public class EdgeNodeSystemService {
28 |
29 | @Value("${spring.application.name}")
30 | private String name;
31 |
32 | @Value("${edgeComputing.seed}")
33 | private Integer seed;
34 |
35 | @Value("${edgeComputing.cpuCapacity}")
36 | private Integer cpuCapacity;
37 |
38 | // refresh
39 | @Value("${edgeComputing.scheduler}")
40 | private String scheduler;
41 |
42 | @Value("${edgeComputing.minCpuCore}")
43 | private int minCpuCore;
44 |
45 | @Autowired
46 | private RestTemplate restTemplate;
47 |
48 | @Autowired
49 | private EdgeNodeSystem edgeNodeSystem;
50 |
51 | @Lazy
52 | @Autowired
53 | private DRLScheduler DRLScheduler;
54 |
55 | @Lazy
56 | @Autowired
57 | private RandomScheduler randomScheduler;
58 |
59 | @Lazy
60 | @Autowired
61 | private ReliabilityTwoChoice reliabilityTwoChoice;
62 |
63 | @Lazy
64 | @Autowired
65 | private ReactiveScheduler reactiveScheduler;
66 |
67 | @Autowired
68 | private EdgeNodeService edgeNodeService;
69 |
70 | @Autowired
71 | private LinkService linkService;
72 |
73 | @Value("${edgeComputing.queueCoef}")
74 | private float queueCoef;
75 |
76 | @Async
77 | public void processUserTask(Task task) {
78 | String service = switch (scheduler) {
79 | case "rl" -> DRLScheduler.selectAction(task.getId());
80 | case "random" -> randomScheduler.selectAction(task.getId());
81 | case "reactive" -> reactiveScheduler.selectAction(task.getId());
82 | case "reliability-two-choice" -> reliabilityTwoChoice.selectAction(task.getId());
83 | default -> throw new RuntimeException("no scheduler");
84 | };
85 | if (service.equals(name)) {
86 | task.setBeginExecutionTime(LocalDateTime.now());
87 | task.setEndTransmissionTime(LocalDateTime.now());
88 | task.setTransmissionTime(0L);
89 | task.setTransmissionWaitingTime(0L);
90 | task.setDestination(name);
91 | processEdgeNodeTask(task);
92 | } else {
93 | Link link = edgeNodeSystem.getLinkMap().get(service);
94 | Double transmissionTime = task.getTaskSize() / link.getTransmissionRate() * 1000;
95 | task.setTransmissionTime(transmissionTime.longValue());
96 | task.setDestination(service);
97 | edgeNodeSystem.getTransmissionQueueMap().get(service).add(task);
98 | }
99 | }
100 |
101 | public void processEdgeNodeTask(Task task) {
102 | Double executionTime = task.getCpuCycle().doubleValue() / edgeNodeSystem.getEdgeNode().getCapacity().doubleValue() * 1000;
103 | task.setExecutionTime(executionTime.longValue());
104 | edgeNodeSystem.getExecutionQueue().add(task);
105 | }
106 |
107 | public void init() {
108 | var edgeNodeConfig = edgeNodeService.getOne(new QueryWrapper().eq("name", name));
109 | var id = Integer.parseInt(name.split("-")[2]);
110 | EdgeNode edgeNode = new EdgeNode();
111 | edgeNode.setId(id);
112 | edgeNode.setName(name);
113 | edgeNode.setCpuNum(edgeNodeConfig.getCpuNum());
114 | edgeNode.setExecutionFailureRate(edgeNodeConfig.getExecutionFailureRate());
115 | edgeNode.setTaskRate(edgeNodeConfig.getTaskRate());
116 | edgeNode.setCapacity(edgeNodeConfig.getCpuNum() * Constants.Giga.value * cpuCapacity);
117 | edgeNodeSystem.setEdgeNode(edgeNode);
118 | edgeNodeSystem.setExecutionQueue(new ExecutionQueue());
119 |
120 | var transmissionQueueMap = new HashMap();
121 | var linkMap = new HashMap();
122 | var links = linkService.list(new QueryWrapper().eq("source", name));
123 | for (Link link : links) {
124 | transmissionQueueMap.put(link.getDestination(), new TransmissionQueue());
125 | linkMap.put(link.getDestination(), link);
126 | }
127 | edgeNodeSystem.setTransmissionQueueMap(transmissionQueueMap);
128 | edgeNodeSystem.setLinkMap(linkMap);
129 | // availAction
130 | edgeNodeSystem.setExecutionQueueThreshold(Float.valueOf(edgeNode.getCpuNum()) / (float) minCpuCore * queueCoef);
131 | log.info("load edge nodes and links configuration completed");
132 | log.info("{} edge config: {}", name, edgeNode);
133 | log.info("{} link config:{}", name, links);
134 | }
135 | }
136 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/thread/ExecutionRunnable.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.thread;
2 |
3 | import cn.edu.scut.bean.EdgeNodeSystem;
4 | import cn.edu.scut.bean.Task;
5 | import cn.edu.scut.bean.TaskStatus;
6 | import cn.edu.scut.service.TaskService;
7 | import lombok.Getter;
8 | import lombok.Setter;
9 | import org.springframework.beans.factory.annotation.Autowired;
10 | import org.springframework.cloud.context.config.annotation.RefreshScope;
11 | import org.springframework.context.annotation.Scope;
12 | import org.springframework.stereotype.Component;
13 |
14 | import java.time.Duration;
15 | import java.time.LocalDateTime;
16 | import java.util.Random;
17 |
18 | @Component
19 | @Setter
20 | @Scope("prototype")
21 | @RefreshScope
22 | public class ExecutionRunnable implements Runnable {
23 |
24 | // prototype
25 | @Getter
26 | private Task task;
27 | @Autowired
28 | private TaskService taskService;
29 |
30 | @Autowired
31 | private EdgeNodeSystem edgeNodeSystem;
32 |
33 | @Autowired
34 | private Random reliabilityRandom;
35 |
36 | @Override
37 | public void run() {
38 | task.setBeginExecutionTime(LocalDateTime.now());
39 | task.setExecutionWaitingTime(Duration.between(task.getEndTransmissionTime(), task.getBeginExecutionTime()).toMillis());
40 | // drop the task without wasting the resource.
41 | long estimatedTotalTime = task.getTransmissionWaitingTime() + task.getTransmissionTime() + task.getExecutionWaitingTime() + task.getExecutionTime();
42 | if (estimatedTotalTime > task.getDeadline()) {
43 | task.setStatus(TaskStatus.DROP);
44 | taskService.updateById(task);
45 | return;
46 | }
47 | try {
48 | Thread.sleep(task.getExecutionTime());
49 |
50 | } catch (InterruptedException e) {
51 | e.printStackTrace();
52 | }
53 | task.setEndExecutionTime(LocalDateTime.now());
54 | // reliability
55 | double reliability = Math.exp(-task.getExecutionTime() / 1000.0 * edgeNodeSystem.getEdgeNode().getExecutionFailureRate());
56 | if (reliabilityRandom.nextDouble() > reliability) {
57 | task.setStatus(TaskStatus.EXECUTION_FAILURE);
58 | }
59 | else {
60 | task.setStatus(TaskStatus.SUCCESS);
61 | }
62 | taskService.updateById(task);
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/thread/TransmissionRunnable.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.thread;
2 |
3 | import cn.edu.scut.bean.EdgeNodeSystem;
4 | import cn.edu.scut.bean.Link;
5 | import cn.edu.scut.bean.Task;
6 | import cn.edu.scut.bean.TaskStatus;
7 | import cn.edu.scut.service.TaskService;
8 | import lombok.Setter;
9 | import lombok.extern.apachecommons.CommonsLog;
10 | import org.springframework.beans.factory.annotation.Autowired;
11 | import org.springframework.cloud.context.config.annotation.RefreshScope;
12 | import org.springframework.context.annotation.Scope;
13 | import org.springframework.stereotype.Component;
14 | import org.springframework.web.client.RestTemplate;
15 |
16 | import java.time.Duration;
17 | import java.time.LocalDateTime;
18 | import java.util.Random;
19 |
20 | @Component
21 | @Scope("prototype")
22 | @Setter
23 | @RefreshScope
24 | @CommonsLog
25 | public class TransmissionRunnable implements Runnable {
26 | // prototype
27 | private Task task;
28 |
29 | @Autowired
30 | private EdgeNodeSystem edgeNodeSystem;
31 |
32 | @Autowired
33 | private Random reliabilityRandom;
34 |
35 | @Autowired
36 | private RestTemplate restTemplate;
37 |
38 | @Autowired
39 | private TaskService taskService;
40 |
41 | @Override
42 | public void run() {
43 | task.setBeginTransmissionTime(LocalDateTime.now());
44 | task.setTransmissionWaitingTime(Duration.between(task.getArrivalTime(), task.getBeginTransmissionTime()).toMillis());
45 | try {
46 | Thread.sleep(task.getTransmissionTime());
47 | } catch (InterruptedException e) {
48 | e.printStackTrace();
49 | }
50 | task.setEndTransmissionTime(LocalDateTime.now());
51 | // reliability
52 | String destination = task.getDestination();
53 | Link link = edgeNodeSystem.getLinkMap().get(destination);
54 | Double transmissionFailureRate = link.getTransmissionFailureRate();
55 | double reliability = Math.exp(-task.getTransmissionTime() / 1000.0 * transmissionFailureRate);
56 | if (reliabilityRandom.nextDouble() > reliability) {
57 | task.setStatus(TaskStatus.TRANSMISSION_FAILURE);
58 | // fixed null type exception
59 | task.setExecutionWaitingTime(0L);
60 | task.setExecutionTime(0L);
61 | taskService.updateById(task);
62 | } else {
63 | // seed task to other edge node
64 | String url = String.format("http://%s/edgeNode/task", destination);
65 | restTemplate.postForObject(url, task, String.class);
66 | }
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/edge-node/src/main/java/cn/edu/scut/utils/SpringBeanUtils.java:
--------------------------------------------------------------------------------
1 | package cn.edu.scut.utils;
2 |
3 | import org.springframework.context.ApplicationContext;
4 |
5 | public class SpringBeanUtils {
6 | public static ApplicationContext applicationContext;
7 |
8 | public static void setApplicationContext(ApplicationContext applicationContext) {
9 | SpringBeanUtils.applicationContext = applicationContext;
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/edge-node/src/main/resources/META-INF/MANIFEST.MF:
--------------------------------------------------------------------------------
1 | Manifest-Version: 1.0
2 | Main-Class: cn.edu.scut.EdgeNodeApp
3 |
4 |
--------------------------------------------------------------------------------
/edge-node/src/main/resources/bootstrap.yaml:
--------------------------------------------------------------------------------
spring:
  application:
    name: ${spring_application_name:edge-node-1}
  profiles:
    # Alternative scheduler profiles:
    # active: edge-computing,reactive
    # active: edge-computing,reliability-two-choice
    # active: edge-computing,mappo
    # active: edge-computing,masac
    active: edge-computing,random
  cloud:
    nacos:
      # The Nacos address can be overridden with the nacos_server_addr environment variable.
      # It falls back to the previously hard-coded address.
      discovery:
        server-addr: ${nacos_server_addr:222.201.187.50:30848}
        watch-delay: 3000
      config:
        server-addr: ${nacos_server_addr:222.201.187.50:30848}
        prefix: application
        file-extension: yaml
--------------------------------------------------------------------------------
/k8s/edge/edge-controller.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Pod
3 | metadata:
4 | namespace: edge-computing
5 | name: edge-controller
6 | labels:
7 | app: edge-controller
8 | spec:
9 | containers:
10 | - name: edge-controller
11 | image: edge-controller:v1.0
12 | imagePullPolicy: IfNotPresent
13 | restartPolicy: Always
--------------------------------------------------------------------------------
/k8s/edge/edge-experiment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Pod
3 | metadata:
4 | namespace: edge-computing
5 | name: edge-experiment
6 | labels:
7 | app: edge-experiment
8 | spec:
9 | containers:
10 | - name: edge-experiment
11 | image: edge-experiment:v1.0
12 | imagePullPolicy: IfNotPresent
13 | restartPolicy: Always
--------------------------------------------------------------------------------
/k8s/edge/edge-node.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Pod
3 | metadata:
4 | namespace: edge-computing
5 | name: $spring_application_name
6 | labels:
7 | app: $spring_application_name
8 | spec:
9 | containers:
10 | - name: $spring_application_name
11 | image: edge-node:v1.0
12 | imagePullPolicy: IfNotPresent
13 | env:
14 | - name: spring_application_name
15 | value: $spring_application_name
16 | restartPolicy: Always
--------------------------------------------------------------------------------
/k8s/mysql.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Pod
3 | metadata:
4 | labels:
5 | run: mysql
6 | name: mysql
7 | spec:
8 | containers:
9 | - env:
10 | - name: MYSQL_ROOT_PASSWORD
11 | value: "123456"
12 | image: mysql
13 | name: mysql
14 | volumeMounts:
15 | - mountPath: /var/lib/mysql
16 | name: data
17 | - mountPath: /etc/mysql/conf.d
18 | name: conf
19 | resources: { }
20 | volumes:
21 | - name: data
22 | nfs:
23 | path: /home/hongcai/lhc_data/mysql/data
24 | server: 222.201.187.50
25 | - name: conf
26 | nfs:
27 | path: /home/hongcai/lhc_data/mysql/conf
28 | server: 222.201.187.50
29 | dnsPolicy: ClusterFirst
30 | restartPolicy: Always
31 |
32 | ---
33 | apiVersion: v1
34 | kind: Service
35 | metadata:
36 | labels:
37 | run: mysql
38 | name: mysql
39 | spec:
40 | ports:
41 | - port: 3306
42 | protocol: TCP
43 | nodePort: 30306
44 | targetPort: 3306
45 | selector:
46 | run: mysql
47 | type: NodePort
--------------------------------------------------------------------------------
/k8s/nacos.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Pod
3 | metadata:
4 | labels:
5 | run: nacos
6 | name: nacos
7 | spec:
8 | containers:
9 | - env:
10 | - name: MODE
11 | value: standalone
12 | image: nacos/nacos-server:v2.1.2
13 | name: nacos
14 | volumeMounts:
15 | - mountPath: /home/nacos/data
16 | name: data
17 | resources: { }
18 | volumes:
19 | - name: data
20 | nfs:
21 | path: /home/hongcai/lhc_data/nacos/data
22 | server: 222.201.187.50
23 | dnsPolicy: ClusterFirst
24 | restartPolicy: Always
25 | ---
26 | apiVersion: v1
27 | kind: Service
28 | metadata:
29 | labels:
30 | run: nacos
31 | name: nacos
32 | spec:
33 | ports:
34 | - port: 8848
35 | name: port1
36 | protocol: TCP
37 | nodePort: 30848
38 | targetPort: 8848
39 | - port: 9848
40 | name: port2
41 | protocol: TCP
42 | nodePort: 31848
43 | targetPort: 9848
44 | selector:
45 | run: nacos
46 | type: NodePort
--------------------------------------------------------------------------------
/nacos/DEFAULT_GROUP/application-edge-computing.yaml:
--------------------------------------------------------------------------------
1 | edgeComputing:
2 | seed: 100
3 | edgeNodeSeed: 100
4 | taskSeed: 100
5 | reliabilitySeed: 100
6 | schedulerSeed: 100
7 | minTaskRate: 0
8 | maxTaskRate: 1.0
9 | cpuCapacity: 3
10 | minCpuCore: 4
11 | maxCpuCore: 32
12 | minTaskSize: 500
13 | maxTaskSize: 2500
14 | minTaskComplexity: 400
15 | maxTaskComplexity: 800
16 | deadline: 1000
17 | minTransmissionRate: 10
18 | maxTransmissionRate: 40
19 | minTransmissionFailureRate: 0
20 | maxTransmissionFailureRate: 0.03
21 | minExecutionFailureRate: 0
22 | maxExecutionFailureRate: 0.4
23 | timeSlot: 100
24 | queueCoef: 3.0
--------------------------------------------------------------------------------
/nacos/DEFAULT_GROUP/application-mappo.yaml:
--------------------------------------------------------------------------------
1 | edgeComputing:
2 | scheduler: rl
3 | flag: flag
4 |
5 | edgeNodeNumber: 10
6 | episodeLimit: 50
7 | episodeNumber: 50
8 |
9 | testNumber: 5
10 | testFrequency: 10
11 |
12 | taskSeed: 102
13 | reliabilitySeed: 102
14 | schedulerSeed: 102
15 | # runner: rl-test
16 | runner: rl-online
17 | rl:
18 | name: mappo
19 | learning-rate: 0.0003
20 | use-learning-rate-decay: false
21 | start-learning-rate: 0.0003
22 | end-learning-rate: 0.0
23 | use-clip-grad: true
24 | clip-grad-coef: 10
25 | use-normalized-reward: true
26 | gamma: 0.99
27 | use-gae: true
28 | gae-lambda: 0.95
29 | clip: 0.2
30 | use-entropy: false
31 | entropy-coef: 0.01
32 | hidden-shape: 64
33 | epoch: 4
34 |
35 | use-trained-model: false
36 | model-flag: model_flag
37 |
38 | buffer-size: 50
39 | batch-size: 50
40 | action-shape: 11
41 | state-shape: 52
42 |
--------------------------------------------------------------------------------
/nacos/DEFAULT_GROUP/application-masac.yaml:
--------------------------------------------------------------------------------
1 | edgeComputing:
2 | scheduler: rl
3 |
4 | flag: flag
5 |
6 | edgeNodeNumber: 10
7 | episodeLimit: 50
8 | episodeNumber: 200
9 |
10 | testNumber: 5
11 | testFrequency: 10
12 |
13 | taskSeed: 102
14 | reliabilitySeed: 102
15 | schedulerSeed: 102
16 | runner: rl-test
17 | # runner: rl-offline
18 | rl:
19 | name: masac
20 | learning-rate: 0.0003
21 | hidden-shape: 64
22 | use-normalized-reward: true
23 |
24 | training-time: 400
25 | test-frequency: 100
26 |
27 | buffer-path: buffer_path
28 | use-cql: true
29 | cql-weight: 0.1
30 | use-soft-update: true
31 | tau: 0.005
32 | use-adaptive-alpha: true
33 | alpha: 0.2
34 | target-entropy-coef: 0.95
35 | # alpha: 0.05
36 | gamma: 0.99
37 |
38 | buffer-size: 10000
39 | batch-size: 128
40 | action-shape: 11
41 | state-shape: 52
42 |
43 | # offline-to-online
44 | use-addition-critic: true
45 |
46 | use-trained-model: true
47 | model-flag: model_flag
--------------------------------------------------------------------------------
/nacos/DEFAULT_GROUP/application-random.yaml:
--------------------------------------------------------------------------------
1 | edgeComputing:
2 | scheduler: random
3 | flag: flag
4 |
5 | edgeNodeNumber: 10
6 | episodeLimit: 50
7 | episodeNumber: 200
8 |
9 | testNumber: 5
10 | testFrequency: 10
11 |
12 | # runner: heuristic-data
13 | runner: heuristic
14 | # runner: rl-offline
15 | # runner: rl-test
16 |
17 | taskSeed: 102
18 | reliabilitySeed: 102
19 | schedulerSeed: 102
20 |
21 | heuristic:
22 | name: random
23 |
--------------------------------------------------------------------------------
/nacos/DEFAULT_GROUP/application-reactive.yaml:
--------------------------------------------------------------------------------
1 | edgeComputing:
2 | scheduler: reactive
3 |
4 | testNumber: 5
5 | testFrequency: 10
6 |
7 | flag: flag
8 |
9 | edgeNodeNumber: 10
10 | episodeLimit: 50
11 | episodeNumber: 5
12 |
13 | taskSeed: 102
14 | reliabilitySeed: 102
15 | schedulerSeed: 102
16 | runner: heuristic
17 |
18 | heuristic:
19 | name: reactive
20 | queue-coef: 1.0
21 |
--------------------------------------------------------------------------------
/nacos/DEFAULT_GROUP/application-reliability-two-choice.yaml:
--------------------------------------------------------------------------------
1 | edgeComputing:
2 | scheduler: reliability-two-choice
3 | flag: flag
4 | edgeNodeNumber: 10
5 | episodeLimit: 50
6 | episodeNumber: 200
7 |
8 | testNumber: 5
9 | testFrequency: 10
10 |
11 | taskSeed: 102
12 | reliabilitySeed: 102
13 | schedulerSeed: 102
14 | runner: heuristic-data
15 | # runner: heuristic
16 |
17 |
18 | heuristic:
19 | name: reliability-two-choice
20 |
21 | rl:
22 | buffer-size: 10000
23 | batch-size: 64
24 | action-shape: 11
25 | state-shape: 52
--------------------------------------------------------------------------------
/nacos/DEFAULT_GROUP/application.yaml:
--------------------------------------------------------------------------------
1 | spring:
2 | datasource:
3 | type: com.alibaba.druid.pool.DruidDataSource
4 | driver-class-name: com.mysql.cj.jdbc.Driver
5 | url: jdbc:mysql://xxx.xxx.xxx.xxx:3306/edge_computing?useUnicode=true&characterEncoding=utf-8&useSSL=false&allowPublicKeyRetrieval=true&serverTimezone=GMT
6 | username: root
7 | password: 123456
8 |
9 | mybatis-plus:
10 | global-config:
11 | db-config:
12 | id-type: auto
13 | table-prefix: ec_
14 | hadoop:
15 | hdfs:
16 | url: hdfs://xxx.xxx.xxx.xxx:9000
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | 4.0.0
6 |
7 | cn.edu.scut
8 | edge-computing
9 | pom
10 | 1.0-SNAPSHOT
11 |
12 | edge-node
13 | edge-controller
14 | edge-api
15 | edge-experiment
16 | edge-algorithm
17 |
18 |
19 |
20 | 17
21 | 17
22 |
23 |
24 |
25 |
26 | djl.ai
27 | https://oss.sonatype.org/content/repositories/snapshots/
28 |
29 |
30 |
31 |
32 |
33 |
34 | org.springframework.boot
35 | spring-boot-dependencies
36 | 2.7.5
37 | pom
38 | import
39 |
40 |
41 |
42 | org.springframework.cloud
43 | spring-cloud-dependencies
44 | 2021.0.4
45 | pom
46 | import
47 |
48 |
49 | com.alibaba.cloud
50 | spring-cloud-alibaba-dependencies
51 | 2021.0.4.0
52 | pom
53 | import
54 |
55 |
56 |
57 |
58 | ai.djl
59 | bom
60 | 0.19.0
61 | pom
62 | import
63 |
64 |
65 |
66 |
67 | com.alibaba
68 | druid-spring-boot-starter
69 | 1.2.15
70 |
71 |
72 | com.baomidou
73 | mybatis-plus-boot-starter
74 | 3.5.2
75 |
76 |
77 | mysql
78 | mysql-connector-java
79 | 8.0.31
80 |
81 |
82 | org.apache.commons
83 | commons-math3
84 | 3.6.1
85 |
86 |
87 | org.apache.hadoop
88 | hadoop-common
89 | 3.3.4
90 |
91 |
92 | org.apache.hadoop
93 | hadoop-hdfs
94 | 3.3.4
95 |
96 |
97 | org.apache.hadoop
98 | hadoop-hdfs-client
99 | 3.3.4
100 |
101 |
102 |
103 | org.springframework.cloud
104 | spring-cloud-starter-bootstrap
105 | 3.1.5
106 |
107 |
108 |
109 |
110 |
--------------------------------------------------------------------------------
/sql/create_table.sql:
--------------------------------------------------------------------------------
-- Schema for the edge-computing experiments.
-- IF NOT EXISTS makes the script safe to re-run against an existing database.
CREATE DATABASE IF NOT EXISTS edge_computing;
USE edge_computing;

CREATE TABLE IF NOT EXISTS `ec_test_task`
(
    `id`                        bigint NOT NULL AUTO_INCREMENT,
    `source`                    varchar(20) DEFAULT NULL,
    `destination`               varchar(20) DEFAULT NULL,
    `status`                    varchar(20) DEFAULT NULL,
    `task_size`                 bigint      DEFAULT NULL,
    `task_complexity`           bigint      DEFAULT NULL,
    `cpu_cycle`                 bigint      DEFAULT NULL,
    `deadline`                  bigint      DEFAULT NULL,
    `transmission_waiting_time` bigint      DEFAULT NULL,
    `transmission_time`         bigint      DEFAULT NULL,
    `execution_waiting_time`    bigint      DEFAULT NULL,
    `execution_time`            bigint      DEFAULT NULL,
    PRIMARY KEY (`id`)
) ENGINE = InnoDB
  AUTO_INCREMENT = 1
  DEFAULT CHARSET = utf8mb4
  COLLATE = utf8mb4_0900_ai_ci;

CREATE TABLE IF NOT EXISTS `ec_edge_node`
(
    `id`                     int NOT NULL AUTO_INCREMENT,
    `cpu_num`                int         DEFAULT NULL,
    `execution_failure_rate` double      DEFAULT NULL,
    `task_rate`              double      DEFAULT NULL,
    `name`                   varchar(20) DEFAULT NULL,
    PRIMARY KEY (`id`)
) ENGINE = InnoDB
  AUTO_INCREMENT = 1
  DEFAULT CHARSET = utf8mb4
  COLLATE = utf8mb4_0900_ai_ci;

CREATE TABLE IF NOT EXISTS `ec_link`
(
    `id`                        int         NOT NULL AUTO_INCREMENT,
    `source`                    varchar(20) NOT NULL,
    `destination`               varchar(20) NOT NULL,
    `transmission_rate`         double      NOT NULL,
    `transmission_failure_rate` double      NOT NULL,
    PRIMARY KEY (`id`)
) ENGINE = InnoDB
  AUTO_INCREMENT = 1
  DEFAULT CHARSET = utf8mb4
  COLLATE = utf8mb4_0900_ai_ci;
--------------------------------------------------------------------------------
/start-experiment.sh:
--------------------------------------------------------------------------------
1 | # maven
2 | mvn clean install -Dmaven.test.skip=true
3 |
4 | # docker
5 | ## edge-node
6 | cd ./edge-node || exit
7 | docker rmi edge-node:v1.0
8 | docker build -t edge-node:v1.0 .
9 | ## edge-controller
10 | cd ../edge-controller || exit
11 | docker rmi edge-controller:v1.0
12 | docker build -t edge-controller:v1.0 .
13 | ## edge-experiment
14 | cd ../edge-experiment || exit
15 | docker rmi edge-experiment:v1.0
16 | docker build -t edge-experiment:v1.0 .
17 |
18 | # kubernetes
19 | ## edge-node
20 | cd ../k8s/edge || exit
21 | for i in {1..10}; do
22 | export spring_application_name=edge-node-$i
23 | envsubst