timeProvider;
29 |
30 | /**
31 | * Creates a new {@link JShellSessionManager} with the specified values.
32 | *
33 | * @param config the config to use for creating the {@link JShellWrapper}s
34 | */
35 | public JShellSessionManager(Config config) {
36 | this.config = config;
37 | this.watchdogExecutorService = Executors.newSingleThreadScheduledExecutor();
38 | this.timeProvider = LocalDateTime::now;
39 |
40 | this.sessionMap = new ConcurrentHashMap<>();
41 |
42 | this.timeToLive = Objects.requireNonNull(
43 | config.getDuration("session.ttl"), "'session.ttl' not set"
44 | );
45 | this.maximumComputationTime = Objects.requireNonNull(
46 | config.getDuration("computation.allotted_time"), "'computation.allotted_time' not set"
47 | );
48 |
49 | this.ticker = new Thread(() -> {
50 | while (!Thread.currentThread().isInterrupted()) {
51 | purgeOld();
52 |
53 | try {
54 | Thread.sleep(timeToLive.dividedBy(2).toMillis());
55 | } catch (InterruptedException e) {
56 | LOGGER.warn("Session housekeeper was interrupted", e);
57 | break;
58 | }
59 | }
60 | }, "JShellSessionManager housekeeper");
61 |
62 | this.ticker.start();
63 | }
64 |
65 | /**
66 | * Returns the {@link JShellWrapper} for the user or creates a new.
67 | *
68 | * @param userId the id of the user
69 | * @return the {@link JShellWrapper} to use
70 | * @throws IllegalStateException if this manager was already shutdown via {@link #shutdown()}
71 | */
72 | public JShellWrapper getSessionOrCreate(String userId) {
73 | if (ticker == null) {
74 | throw new IllegalStateException("This manager was shutdown already.");
75 | }
76 | SessionEntry sessionEntry = sessionMap.computeIfAbsent(
77 | userId,
78 | id -> new SessionEntry(
79 | new JShellWrapper(config,
80 | new TimeWatchdog(watchdogExecutorService, maximumComputationTime)),
81 | id
82 | )
83 | );
84 |
85 | return sessionEntry.getJShell();
86 | }
87 |
88 | /**
89 | * Stops all activity of this manager (running thready, etx.) and frees its resources. You will no
90 | * longer be able to get a {@link jdk.jshell.JShell} from this manager.
91 | *
92 | * Should be called when the system is shut down.
93 | */
94 | public void shutdown() {
95 | // FIXME: 11.04.18 Actually call this to release resources when the bot shuts down
96 | ticker.interrupt();
97 | ticker = null;
98 | sessionMap.values().forEach(sessionEntry -> sessionEntry.getJShell().close());
99 | }
100 |
101 | /**
102 | * Purges sessions that were inactive for longer than the specified threshold.
103 | */
104 | void purgeOld() {
105 | LOGGER.debug("Starting purge");
106 | LocalDateTime now = timeProvider.get();
107 |
108 | // A session could potentially be marked for removal, then another threads retrieves it and updates its
109 | // last accessed state, leading to an unnecessary deletion. This should not have any impact on the caller
110 | // though.
111 | sessionMap.values().removeIf(sessionEntry -> {
112 | Duration delta = Duration.between(now, sessionEntry.getLastAccess()).abs();
113 |
114 | boolean tooOld = delta.compareTo(timeToLive) > 0;
115 |
116 | if (tooOld) {
117 | sessionEntry.getJShell().close();
118 |
119 | LOGGER.debug(
120 | "Removed session for '{}', difference was '{}'",
121 | sessionEntry.getUserId(), delta
122 | );
123 | }
124 |
125 | return tooOld;
126 | });
127 | }
128 |
129 | /**
130 | * Sets the used time provider. Useful for testing only.
131 | *
132 | * @param timeProvider the provider to use
133 | */
134 | void setTimeProvider(Supplier timeProvider) {
135 | this.timeProvider = timeProvider;
136 | }
137 |
138 | private static class SessionEntry {
139 |
140 | private JShellWrapper jshell;
141 | private String userId;
142 | private LocalDateTime lastAccess;
143 |
144 | SessionEntry(JShellWrapper jshell, String userId) {
145 | this.jshell = jshell;
146 | this.userId = userId;
147 | this.lastAccess = LocalDateTime.now();
148 | }
149 |
150 | /**
151 | * Returns the {@link JShellWrapper} and sets the {@link #getLastAccess()} to now.
152 | *
153 | * @return the associated {@link JShellWrapper}
154 | */
155 | JShellWrapper getJShell() {
156 | lastAccess = LocalDateTime.now();
157 | return jshell;
158 | }
159 |
160 | LocalDateTime getLastAccess() {
161 | return lastAccess;
162 | }
163 |
164 | String getUserId() {
165 | return userId;
166 | }
167 | }
168 | }
169 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/execution/JShellWrapper.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.execution;
2 |
3 | import java.io.OutputStream;
4 | import java.io.PrintStream;
5 | import java.nio.charset.StandardCharsets;
6 | import java.util.ArrayList;
7 | import java.util.Collections;
8 | import java.util.List;
9 | import java.util.Map;
10 | import java.util.concurrent.atomic.AtomicBoolean;
11 | import java.util.function.Supplier;
12 | import java.util.stream.Stream;
13 | import jdk.jshell.Diag;
14 | import jdk.jshell.JShell;
15 | import jdk.jshell.Snippet;
16 | import jdk.jshell.SnippetEvent;
17 | import jdk.jshell.SourceCodeAnalysis;
18 | import jdk.jshell.SourceCodeAnalysis.CompletionInfo;
19 | import org.togetherjava.discord.server.Config;
20 | import org.togetherjava.discord.server.io.StringOutputStream;
21 | import org.togetherjava.discord.server.sandbox.AgentAttacher;
22 | import org.togetherjava.discord.server.sandbox.FilteredExecutionControlProvider;
23 | import org.togetherjava.discord.server.sandbox.WhiteBlackList;
24 |
25 | /**
26 | * A light wrapper around {@link JShell}, providing additional features.
27 | */
28 | public class JShellWrapper {
29 |
30 | private static final int MAX_ANALYSIS_DEPTH = 40;
31 |
32 | private JShell jShell;
33 | private StringOutputStream outputStream;
34 | private TimeWatchdog watchdog;
35 |
36 | /**
37 | * Creates a new JShell wrapper using the given config and watchdog.
38 | *
39 | * @param config the config to gather properties from
40 | * @param watchdog the watchdog to schedule kill timer with
41 | */
42 | public JShellWrapper(Config config, TimeWatchdog watchdog) {
43 | this.watchdog = watchdog;
44 | this.outputStream = new StringOutputStream(Character.BYTES * 1600);
45 |
46 | this.jShell = buildJShell(outputStream, config);
47 |
48 | // Initialize JShell using the startup command
49 | jShell.eval(config.getStringOrDefault("java.startup-command", ""));
50 | }
51 |
52 | private JShell buildJShell(OutputStream outputStream, Config config) {
53 | PrintStream out = new PrintStream(outputStream, true, StandardCharsets.UTF_8);
54 | return JShell.builder()
55 | .out(out)
56 | .err(out)
57 | .remoteVMOptions(
58 | AgentAttacher.getCommandLineArgument(),
59 | "-Djava.security.policy=="
60 | + getClass().getResource("/jshell.policy").toExternalForm()
61 | )
62 | .executionEngine(getExecutionControlProvider(config), Map.of())
63 | .build();
64 | }
65 |
66 | private FilteredExecutionControlProvider getExecutionControlProvider(Config config) {
67 | return new FilteredExecutionControlProvider(WhiteBlackList.fromConfig(config));
68 | }
69 |
70 | /**
71 | * Closes the {@link JShell} session.
72 | *
73 | * @see JShell#close()
74 | */
75 | public void close() {
76 | jShell.close();
77 | }
78 |
79 | /**
80 | * Evaluates a command and returns the resulting snippet events and stdout.
81 | *
82 | * May throw an exception.
83 | *
84 | * @param command the command to run
85 | * @return the result of running it
86 | */
87 | public List eval(String command) {
88 | List elementaryCommands = breakApart(command);
89 |
90 | AtomicBoolean stopEvaluation = new AtomicBoolean(false);
91 |
92 | Supplier> work = () -> {
93 | List results = new ArrayList<>();
94 | for (String elementaryCommand : elementaryCommands) {
95 | if (stopEvaluation.get()) {
96 | break;
97 | }
98 | results.add(evalSingle(elementaryCommand));
99 | }
100 | return results;
101 | };
102 |
103 | Runnable killer = () -> {
104 | stopEvaluation.set(true);
105 | jShell.stop();
106 | };
107 |
108 | return watchdog.runWatched(work, killer);
109 | }
110 |
111 | /**
112 | * Evaluates a command and returns the resulting snippet events and stdout.
113 | *
114 | * May throw an exception.
115 | *
116 | * @param command the command to run
117 | * @return the result of running it
118 | */
119 | private JShellResult evalSingle(String command) {
120 | try {
121 | List evaluate = evaluate(command);
122 |
123 | return new JShellResult(evaluate, getStandardOut());
124 | } finally {
125 | // always remove the output stream so it does not linger in case of an exception
126 | outputStream.reset();
127 | }
128 | }
129 |
130 | /**
131 | * Returns the diagnostics for the snippet. This includes things like compilation errors.
132 | *
133 | * @param snippet the snippet to return them for
134 | * @return all found diagnostics
135 | */
136 | public Stream getSnippetDiagnostics(Snippet snippet) {
137 | return jShell.diagnostics(snippet);
138 | }
139 |
140 | private List evaluate(String command) {
141 | return jShell.eval(command);
142 | }
143 |
144 | private List breakApart(String input) {
145 | SourceCodeAnalysis sourceCodeAnalysis = jShell.sourceCodeAnalysis();
146 |
147 | CompletionInfo completionInfo = sourceCodeAnalysis.analyzeCompletion(input);
148 |
149 | int depthCounter = 0;
150 |
151 | List fullCommand = new ArrayList<>();
152 | // source can be null if the input is malformed (e.g. with a method with a syntax error inside)
153 | while (!completionInfo.remaining().isEmpty() && completionInfo.source() != null) {
154 | depthCounter++;
155 |
156 | // should not be needed, but a while true loop here blocks a whole thread with a busy loop and
157 | // might lead to an OOM if the fullCommand list overflows
158 | if (depthCounter > MAX_ANALYSIS_DEPTH) {
159 | break;
160 | }
161 |
162 | fullCommand.add(completionInfo.source());
163 | completionInfo = sourceCodeAnalysis.analyzeCompletion(completionInfo.remaining());
164 | }
165 |
166 | // the final one
167 | if (completionInfo.source() != null) {
168 | fullCommand.add(completionInfo.source());
169 | } else if (completionInfo.remaining() != null) {
170 | // or the remaining if it errored
171 | fullCommand.add(completionInfo.remaining());
172 | }
173 |
174 | return fullCommand;
175 | }
176 |
177 | private String getStandardOut() {
178 | return outputStream.toString();
179 | }
180 |
181 | /**
182 | * Wraps the result of executing JShell.
183 | */
184 | public static class JShellResult {
185 |
186 | private List events;
187 | private String stdout;
188 |
189 | JShellResult(List events, String stdout) {
190 | this.events = events;
191 | this.stdout = stdout == null ? "" : stdout;
192 | }
193 |
194 | public List getEvents() {
195 | return Collections.unmodifiableList(events);
196 | }
197 |
198 | public String getStdOut() {
199 | return stdout;
200 | }
201 | }
202 | }
203 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/execution/TimeWatchdog.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.execution;
2 |
3 | import java.time.Duration;
4 | import java.util.concurrent.ScheduledExecutorService;
5 | import java.util.concurrent.TimeUnit;
6 | import java.util.concurrent.atomic.AtomicBoolean;
7 | import java.util.concurrent.atomic.AtomicInteger;
8 | import java.util.function.Supplier;
9 | import org.slf4j.Logger;
10 | import org.slf4j.LoggerFactory;
11 |
12 | /**
13 | * Runs an action and cancels it if takes too long.
14 | *
15 | * This class is not thread safe and only one action may be running at a
16 | * time.
17 | */
18 | public class TimeWatchdog {
19 |
20 | private static final Logger LOGGER = LoggerFactory.getLogger(TimeWatchdog.class);
21 |
22 | private final ScheduledExecutorService watchdogThreadPool;
23 | private final Duration maxTime;
24 | private final AtomicInteger operationCounter;
25 |
26 | /**
27 | * Creates a new time watchdog running on the given executor service.
28 | *
29 | * @param watchdogThreadPool the executor service to run on
30 | * @param maxTime the maximum duration to allow
31 | */
32 | public TimeWatchdog(ScheduledExecutorService watchdogThreadPool, Duration maxTime) {
33 | this.watchdogThreadPool = watchdogThreadPool;
34 | this.maxTime = maxTime;
35 | this.operationCounter = new AtomicInteger();
36 | }
37 |
38 | /**
39 | * Runs an operation and cancels it if it takes too long.
40 | *
41 | * @param action the action to run
42 | * @param cancelAction cancels the passed action
43 | * @param the type of the result of the operation
44 | * @return the result of the operation
45 | */
46 | public T runWatched(Supplier action, Runnable cancelAction) {
47 | AtomicBoolean killed = new AtomicBoolean(false);
48 | int myId = operationCounter.incrementAndGet();
49 |
50 | watchdogThreadPool.schedule(() -> {
51 | // another calculation was done in the meantime.
52 | if (myId != operationCounter.get()) {
53 | return;
54 | }
55 |
56 | killed.set(true);
57 |
58 | cancelAction.run();
59 | LOGGER.debug("Killed a session (#" + myId + ")");
60 | }, maxTime.toMillis(), TimeUnit.MILLISECONDS);
61 |
62 | T result = action.get();
63 |
64 | if (killed.get()) {
65 | throw new AllottedTimeExceededException(maxTime);
66 | }
67 |
68 | return result;
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/io/StringOutputStream.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.io;
2 |
3 | import java.io.OutputStream;
4 | import java.nio.charset.StandardCharsets;
5 | import java.util.Arrays;
6 |
7 | /**
8 | * An output stream that writes to a string.
9 | */
10 | public class StringOutputStream extends OutputStream {
11 |
12 | private static final int INITIAL_BUFFER_SIZE = 64;
13 |
14 | private final int maxSize;
15 | private byte[] buffer;
16 | private int size;
17 |
18 | public StringOutputStream(int maxSize) {
19 | this.maxSize = maxSize;
20 |
21 | reset();
22 | }
23 |
24 | /**
25 | * Resets this {@link StringOutputStream}, also discarding the buffer.
26 | */
27 | public void reset() {
28 | buffer = new byte[INITIAL_BUFFER_SIZE];
29 | size = 0;
30 | }
31 |
32 | @Override
33 | public void write(int b) {
34 | ensureCapacity();
35 |
36 | if (size < buffer.length && size < maxSize) {
37 | buffer[size++] = (byte) b;
38 | }
39 | }
40 |
41 | private void ensureCapacity() {
42 | if (size >= buffer.length) {
43 | int newSize = size * 2;
44 |
45 | if (newSize > maxSize) {
46 | newSize = maxSize;
47 | }
48 |
49 | buffer = Arrays.copyOf(buffer, newSize);
50 | }
51 | }
52 |
53 | @Override
54 | public String toString() {
55 | if (size < 1) {
56 | return "";
57 | }
58 |
59 | return new String(Arrays.copyOf(buffer, size), StandardCharsets.UTF_8);
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/io/input/InputSanitizer.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.io.input;
2 |
/**
 * Transforms user-supplied JShell input to repair common user mistakes before it is evaluated.
 */
@FunctionalInterface
public interface InputSanitizer {

  /**
   * Produces a corrected version of the given JShell input.
   *
   * @param input the raw input
   * @return the sanitized input
   */
  String sanitize(String input);
}
16 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/io/input/InputSanitizerManager.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.io.input;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | /**
7 | * Stores all registered input sanitizers and provides means to run them all.
8 | */
9 | public class InputSanitizerManager {
10 |
11 | private List sanitizers;
12 |
13 | public InputSanitizerManager() {
14 | this.sanitizers = new ArrayList<>();
15 | addDefaults();
16 | }
17 |
18 | private void addDefaults() {
19 | addSanitizer(new UnicodeQuoteSanitizer());
20 | }
21 |
22 | /**
23 | * Adds a new {@link InputSanitizer}
24 | *
25 | * @param sanitizer the sanitizer to add
26 | */
27 | public void addSanitizer(InputSanitizer sanitizer) {
28 | sanitizers.add(sanitizer);
29 | }
30 |
31 | /**
32 | * Sanitizes a given input using all registered {@link InputSanitizer}s.
33 | *
34 | * @param input the input to sanitize
35 | * @return the resulting input
36 | */
37 | public String sanitize(String input) {
38 | String result = input;
39 | for (InputSanitizer sanitizer : sanitizers) {
40 | result = sanitizer.sanitize(input);
41 | }
42 | return result;
43 | }
44 | }
45 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/io/input/UnicodeQuoteSanitizer.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.io.input;
2 |
3 | /**
4 | * An {@link InputSanitizer} that replaces unicode quotes (as inserted by word/phones) with regular
5 | * ones.
6 | */
7 | public class UnicodeQuoteSanitizer implements InputSanitizer {
8 |
9 | @Override
10 | public String sanitize(String input) {
11 | return input
12 | .replace("“", "\"")
13 | .replace("”", "\"");
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/rendering/CompilationErrorRenderer.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.rendering;
2 |
3 | import java.util.Locale;
4 | import jdk.jshell.Diag;
5 | import net.dv8tion.jda.api.EmbedBuilder;
6 | import net.dv8tion.jda.api.entities.MessageEmbed;
7 |
8 | /**
9 | * Renders error messages.
10 | */
11 | public class CompilationErrorRenderer implements Renderer {
12 |
13 | @Override
14 | public boolean isApplicable(Object param) {
15 | return param instanceof Diag;
16 | }
17 |
18 | @Override
19 | public EmbedBuilder render(Object object, EmbedBuilder builder) {
20 | Diag diag = (Diag) object;
21 | return builder
22 | .addField(
23 | "Error message",
24 | RenderUtils
25 | .truncateAndSanitize(diag.getMessage(Locale.ROOT), MessageEmbed.VALUE_MAX_LENGTH),
26 | false
27 | );
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/rendering/ExceptionRenderer.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.rendering;
2 |
3 | import java.util.Objects;
4 | import jdk.jshell.EvalException;
5 | import jdk.jshell.Snippet.Status;
6 | import net.dv8tion.jda.api.EmbedBuilder;
7 |
8 | /**
9 | * A renderer for exceptions.
10 | */
11 | public class ExceptionRenderer implements Renderer {
12 |
13 | @Override
14 | public boolean isApplicable(Object param) {
15 | return param instanceof Throwable;
16 | }
17 |
18 | @Override
19 | public EmbedBuilder render(Object object, EmbedBuilder builder) {
20 | RenderUtils.applyColor(Status.REJECTED, builder);
21 |
22 | Throwable throwable = (Throwable) object;
23 | builder
24 | .addField("Exception type", throwable.getClass().getSimpleName(), true)
25 | .addField("Message", Objects.toString(throwable.getMessage()), false);
26 |
27 | if (throwable.getCause() != null) {
28 | renderCause(1, throwable, builder);
29 | }
30 |
31 | if (throwable instanceof EvalException) {
32 | EvalException exception = (EvalException) throwable;
33 | builder.addField("Wraps", exception.getExceptionClassName(), true);
34 | }
35 |
36 | return builder;
37 | }
38 |
39 | private void renderCause(int index, Throwable throwable, EmbedBuilder builder) {
40 | builder
41 | .addField("Cause " + index + " type", throwable.getClass().getSimpleName(), false)
42 | .addField("Message", throwable.getMessage(), true);
43 |
44 | if (throwable.getCause() != null) {
45 | renderCause(index + 1, throwable.getCause(), builder);
46 | }
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/rendering/RejectedColorRenderer.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.rendering;
2 |
3 | import jdk.jshell.Snippet.Status;
4 | import jdk.jshell.SnippetEvent;
5 | import net.dv8tion.jda.api.EmbedBuilder;
6 | import org.togetherjava.discord.server.execution.JShellWrapper;
7 |
8 | /**
9 | * A renderer that adjusts the color depending on the status of the snippet.
10 | */
11 | public class RejectedColorRenderer implements Renderer {
12 |
13 | @Override
14 | public boolean isApplicable(Object param) {
15 | return param instanceof JShellWrapper.JShellResult;
16 | }
17 |
18 | @Override
19 | public EmbedBuilder render(Object object, EmbedBuilder builder) {
20 | JShellWrapper.JShellResult result = (JShellWrapper.JShellResult) object;
21 |
22 | for (SnippetEvent snippetEvent : result.getEvents()) {
23 | RenderUtils.applyColor(snippetEvent.status(), builder);
24 | if (snippetEvent.exception() != null) {
25 | RenderUtils.applyColor(Status.REJECTED, builder);
26 | break;
27 | }
28 | }
29 |
30 | return builder;
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/rendering/RenderUtils.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.rendering;
2 |
3 |
4 | import java.awt.Color;
5 | import jdk.jshell.Snippet;
6 | import net.dv8tion.jda.api.EmbedBuilder;
7 |
8 | /**
9 | * Contains utility functions for rendering.
10 | */
11 | class RenderUtils {
12 |
13 | static int NEWLINE_MAXIMUM = 10;
14 |
15 | private static final Color ERROR_COLOR = new Color(255, 99, 71);
16 | private static final Color SUCCESS_COLOR = new Color(118, 255, 0);
17 | private static final Color OVERWRITTEN_COLOR = SUCCESS_COLOR;
18 | private static final Color RECOVERABLE_COLOR = new Color(255, 181, 71);
19 |
20 | /**
21 | * Truncates the String to the max length and sanitizes it a bit.
22 | *
23 | * @param input the input string
24 | * @param maxLength the maximum length it can have
25 | * @return the processed string
26 | */
27 | static String truncateAndSanitize(String input, int maxLength) {
28 | StringBuilder result = new StringBuilder();
29 |
30 | int newLineCount = 0;
31 | for (int codePoint : input.codePoints().toArray()) {
32 | if (codePoint == '\n') {
33 | newLineCount++;
34 | }
35 |
36 | if (codePoint == '\n' && newLineCount > NEWLINE_MAXIMUM) {
37 | result.append("⏎");
38 | } else {
39 | result.append(Character.toChars(codePoint));
40 | }
41 | }
42 |
43 | return truncate(result.toString(), maxLength);
44 | }
45 |
46 | private static String truncate(String input, int maxLength) {
47 | if (input.length() <= maxLength) {
48 | return input;
49 | }
50 | return input.substring(0, maxLength);
51 | }
52 |
53 | /**
54 | * Applies the given color to the embed.
55 | *
56 | * @param status the status
57 | * @param builder the builder to apply it to
58 | */
59 | static void applyColor(Snippet.Status status, EmbedBuilder builder) {
60 | switch (status) {
61 | case VALID:
62 | builder.setColor(SUCCESS_COLOR);
63 | break;
64 | case OVERWRITTEN:
65 | builder.setColor(OVERWRITTEN_COLOR);
66 | break;
67 | case REJECTED:
68 | case DROPPED:
69 | case NONEXISTENT:
70 | builder.setColor(ERROR_COLOR);
71 | break;
72 | case RECOVERABLE_DEFINED:
73 | case RECOVERABLE_NOT_DEFINED:
74 | builder.setColor(RECOVERABLE_COLOR);
75 | break;
76 | }
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/rendering/Renderer.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.rendering;
2 |
3 | import net.dv8tion.jda.api.EmbedBuilder;
4 |
5 | /**
6 | * A renderer takes care of displaying some message in an embed.
7 | */
8 | public interface Renderer {
9 |
10 | /**
11 | * Checks if this renderer can render the given object.
12 | *
13 | * @param param the object to check
14 | * @return true if this renderer can handle the passed object
15 | */
16 | boolean isApplicable(Object param);
17 |
18 | /**
19 | * Renders the given object to the {@link EmbedBuilder}.
20 | *
21 | * @param object the object to render
22 | * @param builder the {@link EmbedBuilder} to modify
23 | * @return the rendered {@link EmbedBuilder}
24 | */
25 | EmbedBuilder render(Object object, EmbedBuilder builder);
26 | }
27 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/rendering/RendererManager.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.rendering;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 | import jdk.jshell.Snippet.Status;
6 | import jdk.jshell.SnippetEvent;
7 | import net.dv8tion.jda.api.EmbedBuilder;
8 | import org.togetherjava.discord.server.execution.JShellWrapper;
9 |
10 | /**
11 | * Contains {@link Renderer}s and allows running them in series.
12 | */
13 | public class RendererManager {
14 |
15 | private List rendererList;
16 | private Renderer catchAll;
17 |
18 | public RendererManager() {
19 | this.rendererList = new ArrayList<>();
20 | this.catchAll = new StringCatchallRenderer();
21 |
22 | addRenderer(new ExceptionRenderer());
23 | addRenderer(new StandardOutputRenderer());
24 | addRenderer(new CompilationErrorRenderer());
25 | addRenderer(new RejectedColorRenderer());
26 | }
27 |
28 | /**
29 | * Adds the given renderer to this manager.
30 | *
31 | * @param renderer the renderer to add
32 | */
33 | private void addRenderer(Renderer renderer) {
34 | rendererList.add(renderer);
35 | }
36 |
37 | /**
38 | * Renders a given result to the passed {@link EmbedBuilder}.
39 | *
40 | * @param builder the builder to render to
41 | * @param result the {@link org.togetherjava.discord.server.execution.JShellWrapper.JShellResult}
42 | * to render
43 | */
44 | public void renderJShellResult(EmbedBuilder builder, JShellWrapper.JShellResult result) {
45 | RenderUtils.applyColor(Status.VALID, builder);
46 |
47 | renderObject(builder, result);
48 |
49 | for (SnippetEvent snippetEvent : result.getEvents()) {
50 | renderObject(builder, snippetEvent.exception());
51 | renderObject(builder, snippetEvent.value());
52 | }
53 | }
54 |
55 | /**
56 | * Renders an object to a builder.
57 | *
58 | * @param builder the builder to render to
59 | * @param object the object to render
60 | */
61 | public void renderObject(EmbedBuilder builder, Object object) {
62 | if (object == null) {
63 | return;
64 | }
65 |
66 | boolean rendered = false;
67 | for (Renderer renderer : rendererList) {
68 | if (renderer.isApplicable(object)) {
69 | rendered = true;
70 | renderer.render(object, builder);
71 | }
72 | }
73 |
74 | if (!rendered && catchAll.isApplicable(object)) {
75 | catchAll.render(object, builder);
76 | }
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/rendering/StandardOutputRenderer.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.rendering;
2 |
3 | import net.dv8tion.jda.api.EmbedBuilder;
4 | import net.dv8tion.jda.api.entities.MessageEmbed;
5 | import org.togetherjava.discord.server.execution.JShellWrapper;
6 |
7 | /**
8 | * A renderer for the standard output result.
9 | */
10 | public class StandardOutputRenderer implements Renderer {
11 |
12 | @Override
13 | public boolean isApplicable(Object param) {
14 | return param instanceof JShellWrapper.JShellResult;
15 | }
16 |
17 | @Override
18 | public EmbedBuilder render(Object object, EmbedBuilder builder) {
19 | JShellWrapper.JShellResult result = (JShellWrapper.JShellResult) object;
20 | if (result.getStdOut().isEmpty()) {
21 | return builder;
22 | }
23 | String output;
24 |
25 | // Discord rejects all-whitespace fields so we need to guard them with a code block
26 | // Inline code swallows leading and trailing whitespaces, so it is sadly not up to the task
27 | if (result.getStdOut().chars().allMatch(Character::isWhitespace)) {
28 | final int fenceLength = "```\n```".length();
29 | String inner = RenderUtils
30 | .truncateAndSanitize(result.getStdOut(), MessageEmbed.VALUE_MAX_LENGTH - fenceLength);
31 | output = "```\n" + inner + "```";
32 | } else {
33 | output = RenderUtils.truncateAndSanitize(result.getStdOut(), MessageEmbed.VALUE_MAX_LENGTH);
34 | }
35 |
36 | return builder.addField("Output", output, true);
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/rendering/StringCatchallRenderer.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.rendering;
2 |
3 | import java.util.Objects;
4 |
5 | import net.dv8tion.jda.api.EmbedBuilder;
6 | import net.dv8tion.jda.api.entities.MessageEmbed;
7 |
8 | /**
9 | * A renderer for results that just renders whatever hasn't been renderer yet as a string.
10 | */
11 | public class StringCatchallRenderer implements Renderer {
12 |
13 | @Override
14 | public boolean isApplicable(Object param) {
15 | return !Objects.toString(param).isEmpty();
16 | }
17 |
18 | @Override
19 | public EmbedBuilder render(Object object, EmbedBuilder builder) {
20 | return builder.addField(
21 | "Result",
22 | RenderUtils.truncateAndSanitize(Objects.toString(object), MessageEmbed.VALUE_MAX_LENGTH),
23 | true
24 | );
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/sandbox/AgentAttacher.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.sandbox;
2 |
3 | import java.nio.file.Path;
4 | import me.ialistannen.jvmagentutils.instrumentation.JvmUtils;
5 |
6 | public class AgentAttacher {
7 |
8 | private static final Path agentJar = JvmUtils.generateAgentJar(
9 | AgentMain.class, AgentMain.class, JshellSecurityManager.class
10 | );
11 |
12 | /**
13 | * Returns the command line argument that attaches the agent.
14 | *
15 | * @return the command line argument to start it
16 | */
17 | public static String getCommandLineArgument() {
18 | return "-javaagent:" + agentJar.toAbsolutePath();
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/sandbox/AgentMain.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.sandbox;
2 |
3 | import java.lang.instrument.ClassFileTransformer;
4 | import java.lang.instrument.Instrumentation;
5 |
6 | /**
7 | * An agent that sets the security manager JShell uses.
8 | */
9 | public class AgentMain implements ClassFileTransformer {
10 |
11 | public static void premain(String args, Instrumentation inst) {
12 | System.setSecurityManager(new JshellSecurityManager());
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/sandbox/FilteredExecutionControl.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.sandbox;
2 |
3 | import jdk.jshell.execution.LocalExecutionControl;
4 | import org.objectweb.asm.ClassReader;
5 | import org.objectweb.asm.ClassVisitor;
6 | import org.objectweb.asm.Handle;
7 | import org.objectweb.asm.MethodVisitor;
8 | import org.objectweb.asm.Opcodes;
9 | import org.slf4j.Logger;
10 | import org.slf4j.LoggerFactory;
11 |
12 | public class FilteredExecutionControl extends LocalExecutionControl {
13 |
14 | private static final Logger LOGGER = LoggerFactory.getLogger(FilteredExecutionControl.class);
15 |
16 | private final WhiteBlackList whiteBlackList;
17 |
18 | /**
19 | * Creates a new {@link FilteredExecutionControl}.
20 | *
21 | * @param whiteBlackList the {@link WhiteBlackList}
22 | */
23 | FilteredExecutionControl(WhiteBlackList whiteBlackList) {
24 | this.whiteBlackList = whiteBlackList;
25 | }
26 |
27 | @Override
28 | public void load(ClassBytecodes[] cbcs)
29 | throws ClassInstallException, NotImplementedException, EngineTerminationException {
30 | for (ClassBytecodes bytecodes : cbcs) {
31 | ClassReader classReader = new ClassReader(bytecodes.bytecodes());
32 | classReader.accept(new ClassVisitor(Opcodes.ASM6) {
33 | @Override
34 | public MethodVisitor visitMethod(int access, String name, String descriptor,
35 | String signature,
36 | String[] exceptions) {
37 | return new FilteringMethodVisitor();
38 | }
39 | }, 0);
40 | }
41 |
42 | super.load(cbcs);
43 | }
44 |
45 | private boolean isBlocked(String name) {
46 | return whiteBlackList.isBlocked(name);
47 | }
48 |
49 | private boolean isPackageOrParentBlocked(String sanitizedPackage) {
50 | if (sanitizedPackage == null || sanitizedPackage.isEmpty()) {
51 | return false;
52 | }
53 | if (isBlocked(sanitizedPackage)) {
54 | return true;
55 | }
56 |
57 | int nextDot = sanitizedPackage.lastIndexOf('.');
58 |
59 | return nextDot >= 0 && isPackageOrParentBlocked(sanitizedPackage.substring(0, nextDot));
60 | }
61 |
62 |
63 | private class FilteringMethodVisitor extends MethodVisitor {
64 |
65 | private FilteringMethodVisitor() {
66 | super(Opcodes.ASM6);
67 | }
68 |
69 | @Override
70 | public void visitMethodInsn(int opcode, String owner, String name, String descriptor,
71 | boolean isInterface) {
72 | checkAccess(owner, name);
73 | }
74 |
75 | @Override
76 | public void visitFieldInsn(int opcode, String owner, String name, String descriptor) {
77 | checkAccess(owner, name);
78 | }
79 |
80 | private void checkAccess(String owner, String name) {
81 | String sanitizedClassName = sanitizeClassName(owner);
82 |
83 | if (isBlocked(sanitizedClassName)) {
84 | throw new UnsupportedOperationException("Naughty (class): " + sanitizedClassName);
85 | }
86 | if (isBlocked(sanitizedClassName + "#" + name)) {
87 | throw new UnsupportedOperationException(
88 | "Naughty (meth): " + sanitizedClassName + "#" + name
89 | );
90 | }
91 |
92 | // do not check the package if the class or method was explicitely allowed
93 | if (whiteBlackList.isWhitelisted(sanitizedClassName)
94 | || whiteBlackList.isWhitelisted(sanitizedClassName + "#" + name)) {
95 | return;
96 | }
97 |
98 | if (isPackageOrParentBlocked(sanitizedClassName)) {
99 | throw new UnsupportedOperationException("Naughty (pack): " + sanitizedClassName);
100 | }
101 | }
102 |
103 | private String sanitizeClassName(String owner) {
104 | return owner.replace("/", ".");
105 | }
106 |
107 | @Override
108 | public void visitInvokeDynamicInsn(String name, String descriptor, Handle bootstrapMethodHandle,
109 | Object... bootstrapMethodArguments) {
110 | // TODO: 04.04.18 Implement this method
111 | LOGGER.warn("Calling dymn " + name + " " + descriptor + " " + bootstrapMethodHandle);
112 | }
113 | }
114 | }
115 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/sandbox/FilteredExecutionControlProvider.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.sandbox;
2 |
3 | import java.lang.reflect.InvocationHandler;
4 | import java.lang.reflect.InvocationTargetException;
5 | import java.lang.reflect.Method;
6 | import java.lang.reflect.Proxy;
7 | import java.util.Map;
8 | import java.util.function.Supplier;
9 | import jdk.jshell.execution.JdiExecutionControlProvider;
10 | import jdk.jshell.spi.ExecutionControl;
11 | import jdk.jshell.spi.ExecutionControlProvider;
12 | import jdk.jshell.spi.ExecutionEnv;
13 |
14 | public class FilteredExecutionControlProvider implements ExecutionControlProvider {
15 |
16 | private final JdiExecutionControlProvider jdiExecutionControlProvider;
17 | private final Supplier executionControlSupplier;
18 |
19 | public FilteredExecutionControlProvider(WhiteBlackList whiteBlackList) {
20 | this.jdiExecutionControlProvider = new JdiExecutionControlProvider();
21 | this.executionControlSupplier = () -> new FilteredExecutionControl(whiteBlackList);
22 | }
23 |
24 | @Override
25 | public String name() {
26 | return "filtered";
27 | }
28 |
29 | @Override
30 | public ExecutionControl generate(ExecutionEnv env, Map parameters)
31 | throws Throwable {
32 | ExecutionControl hijackedExecutionControl = jdiExecutionControlProvider
33 | .generate(env, parameters);
34 | FilteredExecutionControl filteredExecutionControl = executionControlSupplier.get();
35 |
36 | return (ExecutionControl) Proxy.newProxyInstance(
37 | getClass().getClassLoader(),
38 | new Class[]{ExecutionControl.class},
39 | new ExecutionControlDelegatingProxy(filteredExecutionControl, hijackedExecutionControl)
40 | );
41 | }
42 |
43 | private static class ExecutionControlDelegatingProxy implements InvocationHandler {
44 |
45 | private FilteredExecutionControl target;
46 | private ExecutionControl hijackedExecutionControl;
47 |
48 | private ExecutionControlDelegatingProxy(FilteredExecutionControl target,
49 | ExecutionControl hijackedExecutionControl) {
50 | this.target = target;
51 | this.hijackedExecutionControl = hijackedExecutionControl;
52 | }
53 |
54 | @Override
55 | public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
56 | if ("load".equals(method.getName())
57 | && method.getParameterTypes()[0] == ExecutionControl.ClassBytecodes[].class
58 | && args.length != 0) {
59 |
60 | target.load((ExecutionControl.ClassBytecodes[]) args[0]);
61 | }
62 |
63 | // this unwrapping is necessary for JShell to detect that an exception it can handle was thrown
64 | try {
65 | return method.invoke(hijackedExecutionControl, args);
66 | } catch (InvocationTargetException e) {
67 | throw e.getCause();
68 | }
69 | }
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/sandbox/JshellSecurityManager.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.sandbox;
2 |
3 | import java.security.Permission;
4 | import java.util.Arrays;
5 |
/**
 * A {@link SecurityManager} that only enforces permission checks on behalf of JShell snippets.
 *
 * <p>A check is always allowed when this manager already appears further down the stack (it
 * initiated the check itself), or when one of the whitelisted JDK classes is on the stack.
 * Otherwise, the default check is applied, but only if a class whose name contains {@code REPL} is
 * on the stack.
 */
public class JshellSecurityManager extends SecurityManager {

  private static final String[] WHITELISTED_CLASSES = {
      // CallSite is needed for lambdas
      "java.lang.invoke.CallSite",
      // enum set/map because they do a reflective invocation to get the universe
      // let's hope that is actually safe and EnumSet/Map can not be used to invoke arbitrary code
      "java.util.EnumSet", "java.util.EnumMap",
      // Character.getName accesses a system resource (uniName.dat)
      "java.lang.CharacterName",
      // Locale specific decimal formatting
      "java.text.DecimalFormatSymbols"
  };


  @Override
  public void checkPermission(Permission perm) {
    // self-initiated checks and whitelisted JDK internals (e.g. lambda init) are always allowed
    boolean exempt = comesFromMe() || containsWhitelistedClass();

    // everything but JShell code bypasses the check
    if (exempt || !comesFromJshell()) {
      return;
    }
    super.checkPermission(perm);
  }

  private boolean comesFromJshell() {
    for (Class<?> caller : getClassContext()) {
      if (caller.getName().contains("REPL")) {
        return true;
      }
    }
    return false;
  }

  private boolean comesFromMe() {
    Class<?>[] callers = getClassContext();
    // index 0 is this method and index 1 is checkPermission. Any later occurrence of this class
    // means the security manager itself initiated the call.
    for (int index = 2; index < callers.length; index++) {
      if (callers[index] == getClass()) {
        return true;
      }
    }
    return false;
  }

  private boolean containsWhitelistedClass() {
    for (Class<?> caller : getClassContext()) {
      if (Arrays.asList(WHITELISTED_CLASSES).contains(caller.getName())) {
        return true;
      }
    }
    return false;
  }
}
66 |
--------------------------------------------------------------------------------
/src/main/java/org/togetherjava/discord/server/sandbox/WhiteBlackList.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.sandbox;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Arrays;
5 | import java.util.List;
6 | import java.util.regex.Pattern;
7 | import org.togetherjava.discord.server.Config;
8 |
9 | /**
10 | * A black- or whitelist.
11 | */
12 | public class WhiteBlackList {
13 |
14 | private List blacklist;
15 | private List whitelist;
16 |
17 | /**
18 | * Creates a new black- or whitelist.
19 | */
20 | public WhiteBlackList() {
21 | this.whitelist = new ArrayList<>();
22 | this.blacklist = new ArrayList<>();
23 | }
24 |
25 | /**
26 | * Adds a new pattern to the blacklist.
27 | *
28 | * @param pattern the pattern
29 | */
30 | public void blacklist(String pattern) {
31 | blacklist.add(Pattern.compile(pattern));
32 | }
33 |
34 | /**
35 | * Adds a new pattern to the whitelist.
36 | *
37 | * @param pattern the pattern
38 | */
39 | public void whitelist(String pattern) {
40 | whitelist.add(Pattern.compile(pattern));
41 | }
42 |
43 | /**
44 | * Returns whether a given input is blocked.
45 | *
46 | * @param input the input
47 | * @return true if it is blocked
48 | */
49 | public boolean isBlocked(String input) {
50 | return matches(blacklist, input) && !matches(whitelist, input);
51 | }
52 |
53 | /**
54 | * Returns whether a given input is whitelisted.
55 | *
56 | * @param input the input
57 | * @return true if it is blocked
58 | */
59 | public boolean isWhitelisted(String input) {
60 | return matches(whitelist, input);
61 | }
62 |
63 | private static boolean matches(List patterns, String input) {
64 | return patterns.stream().anyMatch(pattern -> pattern.matcher(input).matches());
65 | }
66 |
67 | /**
68 | * Creates the white- or blacklist with the values in the config object.
69 | *
70 | * @param config the config object
71 | * @return the created list
72 | */
73 | public static WhiteBlackList fromConfig(Config config) {
74 | String[] blacklist = config.getStringOrDefault("sandbox.blacklist", "")
75 | .split(",");
76 | String[] whitelist = config.getStringOrDefault("sandbox.whitelist", "")
77 | .split(",");
78 |
79 | WhiteBlackList list = new WhiteBlackList();
80 |
81 | Arrays.stream(blacklist).forEach(list::blacklist);
82 | Arrays.stream(whitelist).forEach(list::whitelist);
83 |
84 | return list;
85 | }
86 |
87 | @Override
88 | public String toString() {
89 | return "WhiteBlackList{" +
90 | "blacklist=" + blacklist +
91 | ", whitelist=" + whitelist +
92 | '}';
93 | }
94 | }
95 |
--------------------------------------------------------------------------------
/src/main/resources/bot.properties:
--------------------------------------------------------------------------------
1 | # The prefix for commands
2 | prefix=!jshell
3 | # The bot token
4 | token=yourtokengoeshere
5 | # How long a JShell session is kept around. Short means the history will be lost earlier,
6 | # but it will need less server resources, if many people use it.
7 | session.ttl=PT15M
8 | # The maximum time a single command can take before it is killed
9 | computation.allotted_time=PT15S
10 | # Whether to auto delete the bot's messages
11 | messages.auto_delete=false
12 | # How long to wait before the bot's messages are deleted
13 | messages.auto_delete.duration=PT15M
14 | # The maximum amount of embeds to show for multi-snippet inputs
15 | messages.max_context_display_amount=3
16 | # Blacklisted packages, classes and methods (comma separated regular expressions that must match the whole name).
17 | # Format for packages "com.package"
18 | # Format for classes "fully.qualified.Name"
19 | # Format for methods "fully.qualified.Name#methodName"
20 | sandbox.blacklist=sun,\
21 | jdk,\
22 | java.lang.reflect,\
23 | java.lang.invoke,\
24 | java.util.concurrent,\
25 | org.togetherjava,\
26 | java.lang.ProcessBuilder,\
27 | java.lang.ProcessHandle,\
28 | java.lang.Runtime,\
29 | java.lang.System#exit,\
30 | java.lang.Thread#sleep,\
31 | java.lang.Thread#wait,\
32 | java.lang.Thread#notify,\
33 | java.lang.Thread#currentThread,\
34 | java.lang.Thread#start
35 | # The packages, classes, and methods to explicitly whitelist.
36 | # Same format as the blacklist above
37 | sandbox.whitelist=java.util.concurrent.atomic,\
38 | java.util.concurrent.Concurrent.*,\
39 | java.util.concurrent..*Queue,\
40 | java.util.concurrent.CopyOnWrite.*,\
41 | java.util.concurrent.ThreadLocalRandom.*
42 | # Commands JShell runs when starting up.
43 | java.startup-command=import java.io.*;\
44 | import java.math.*;\
45 | import java.net.*;\
46 | import java.nio.file.*;\
47 | import java.util.*;\
48 | import java.util.concurrent.*;\
49 | import java.util.function.*;\
50 | import java.util.prefs.*;\
51 | import java.util.regex.*;\
52 | import java.util.stream.*;
--------------------------------------------------------------------------------
/src/main/resources/jshell.policy:
--------------------------------------------------------------------------------
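1 | // Restrict what Jshell can run. NOTE(review): RuntimePermission is java.lang.RuntimePermission; java.util.RuntimePermission does not exist, so these entries cannot grant anything - confirm intent.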
2 | grant {
3 | permission java.util.RuntimePermission "accessDeclaredMembers";
4 | permission java.util.RuntimePermission "accessClassInPackage";
5 | };
6 |
--------------------------------------------------------------------------------
/src/main/resources/logback.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n
10 |
11 |
12 |
13 |
14 |
15 | logs/bot.%d{yyyy-MM-dd}.log
16 | 90
17 |
18 |
19 | UTF-8
20 | %d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %-40.40logger{39} : %msg%n
21 |
22 | true
23 |
24 |
25 |
26 | 512
27 |
28 |
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/src/test/java/org/togetherjava/discord/server/execution/JShellSessionManagerTest.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.execution;
2 |
3 | import org.junit.jupiter.api.AfterEach;
4 | import org.junit.jupiter.api.BeforeEach;
5 | import org.junit.jupiter.api.Test;
6 | import org.togetherjava.discord.server.Config;
7 |
8 | import java.time.Duration;
9 | import java.time.LocalDateTime;
10 | import java.util.Properties;
11 |
12 | import static org.junit.jupiter.api.Assertions.assertEquals;
13 | import static org.junit.jupiter.api.Assertions.assertNotEquals;
14 |
15 | class JShellSessionManagerTest {
16 |
17 | private JShellSessionManager jShellSessionManager;
18 | private static Duration sessionTTL = Duration.ofSeconds(15);
19 |
20 | @BeforeEach
21 | void setUp() {
22 | Properties properties = new Properties();
23 | properties.setProperty("session.ttl", "PT15S");
24 | properties.setProperty("computation.allotted_time", "PT15S");
25 | Config config = new Config(properties);
26 | jShellSessionManager = new JShellSessionManager(config);
27 | }
28 |
29 | @AfterEach
30 | void tearDown() {
31 | jShellSessionManager.shutdown();
32 | }
33 |
34 | @Test
35 | void cachesSessions() {
36 | String userId = "1";
37 | JShellWrapper session = jShellSessionManager.getSessionOrCreate(userId);
38 | JShellWrapper secondCall = jShellSessionManager.getSessionOrCreate(userId);
39 |
40 | assertEquals(session, secondCall, "Sessions differ");
41 | }
42 |
43 | @Test
44 | void createsNewSessionForDifferentUser() {
45 | JShellWrapper session = jShellSessionManager.getSessionOrCreate("1");
46 | JShellWrapper secondCall = jShellSessionManager.getSessionOrCreate("2");
47 |
48 | assertNotEquals(session, secondCall, "Sessions are the same");
49 | }
50 |
51 | @Test
52 | void timesOutSessions() {
53 | String userId = "1";
54 | JShellWrapper session = jShellSessionManager.getSessionOrCreate(userId);
55 |
56 | jShellSessionManager.setTimeProvider(() -> LocalDateTime.now().plus(sessionTTL).plusSeconds(5));
57 | jShellSessionManager.purgeOld();
58 |
59 | assertNotEquals(session, jShellSessionManager.getSessionOrCreate(userId), "Session was not expired");
60 |
61 | // restore old
62 | jShellSessionManager.setTimeProvider(LocalDateTime::now);
63 | }
64 |
65 | @Test
66 | void cachesSessionsOverTime() {
67 | String userId = "1";
68 | JShellWrapper session = jShellSessionManager.getSessionOrCreate(userId);
69 |
70 | jShellSessionManager.setTimeProvider(() -> LocalDateTime.now().plus(sessionTTL).minusSeconds(5));
71 | jShellSessionManager.purgeOld();
72 |
73 | assertEquals(session, jShellSessionManager.getSessionOrCreate(userId), "Session was expired");
74 |
75 | // restore old
76 | jShellSessionManager.setTimeProvider(LocalDateTime::now);
77 | }
78 | }
--------------------------------------------------------------------------------
/src/test/java/org/togetherjava/discord/server/execution/JShellWrapperTest.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.execution;
2 |
3 | import static org.junit.jupiter.api.Assertions.assertEquals;
4 | import static org.junit.jupiter.api.Assertions.assertFalse;
5 | import static org.junit.jupiter.api.Assertions.assertNull;
6 | import static org.junit.jupiter.api.Assertions.assertThrows;
7 | import static org.junit.jupiter.api.Assertions.assertTrue;
8 |
9 | import java.time.Duration;
10 | import java.util.List;
11 | import java.util.Properties;
12 | import java.util.concurrent.Executors;
13 | import java.util.stream.Collectors;
14 | import jdk.jshell.Diag;
15 | import jdk.jshell.Snippet.Status;
16 | import jdk.jshell.SnippetEvent;
17 | import org.junit.jupiter.api.AfterAll;
18 | import org.junit.jupiter.api.BeforeAll;
19 | import org.junit.jupiter.api.Test;
20 | import org.junit.jupiter.params.ParameterizedTest;
21 | import org.junit.jupiter.params.provider.ValueSource;
22 | import org.togetherjava.discord.server.Config;
23 | import org.togetherjava.discord.server.execution.JShellWrapper.JShellResult;
24 |
25 | class JShellWrapperTest {
26 |
27 | private static JShellWrapper wrapper;
28 |
29 | @BeforeAll
30 | static void setupWrapper() {
31 | Properties properties = new Properties();
32 | properties.setProperty("sandbox.blacklist", "java.time");
33 | Config config = new Config(properties);
34 | TimeWatchdog timeWatchdog = new TimeWatchdog(
35 | Executors.newScheduledThreadPool(1),
36 | Duration.ofMinutes(20)
37 | );
38 | wrapper = new JShellWrapper(config, timeWatchdog);
39 | }
40 |
41 | @AfterAll
42 | static void cleanup() {
43 | wrapper.close();
44 | }
45 |
46 | @Test
47 | void reportsCompileTimeError() {
48 | // 1crazy is an invalid variable name
49 | JShellWrapper.JShellResult result = wrapper.eval("1crazy").get(0);
50 |
51 | assertFalse(result.getEvents().isEmpty(), "Found no events");
52 |
53 | for (SnippetEvent snippetEvent : result.getEvents()) {
54 | List diags = wrapper.getSnippetDiagnostics(snippetEvent.snippet())
55 | .collect(Collectors.toList());
56 | assertFalse(diags.isEmpty(), "Has no diagnostics");
57 | assertTrue(diags.get(0).isError(), "Diagnostic is no error");
58 | }
59 | }
60 |
61 | @Test
62 | void correctlyComputesExpression() {
63 | JShellWrapper.JShellResult result = wrapper.eval("1+1").get(0);
64 |
65 | assertEquals(result.getEvents().size(), 1, "Event count is not 1");
66 |
67 | SnippetEvent snippetEvent = result.getEvents().get(0);
68 |
69 | assertNull(snippetEvent.exception(), "An exception occurred");
70 |
71 | assertEquals("2", snippetEvent.value(), "Calculation was wrong");
72 | }
73 |
74 | @Test
75 | void savesHistory() {
76 | wrapper.eval("int test = 1+1;");
77 | JShellWrapper.JShellResult result = wrapper.eval("test").get(0);
78 |
79 | assertEquals(result.getEvents().size(), 1, "Event count is not 1");
80 |
81 | SnippetEvent snippetEvent = result.getEvents().get(0);
82 |
83 | assertNull(snippetEvent.exception(), "An exception occurred");
84 |
85 | assertEquals("2", snippetEvent.value(), "Calculation was wrong");
86 | }
87 |
88 | @Test
89 | void blocksPackage() {
90 | assertThrows(
91 | UnsupportedOperationException.class,
92 | () -> wrapper.eval("java.time.LocalDateTime.now()"),
93 | "No exception was thrown when accessing a blocked package."
94 | );
95 | }
96 |
97 | @ParameterizedTest(name = "Accessing \"{0}\" should fail")
98 | @ValueSource(strings = {
99 | "/opt",
100 | "~",
101 | "/tmp/"
102 | })
103 | void blocksFileAccess(String fileName) {
104 | JShellResult result = wrapper.eval("new java.io.File(\"" + fileName + "\").listFiles()").get(0);
105 |
106 | if (!allFailed(result)) {
107 | printSnippetResult(result);
108 | }
109 |
110 | assertTrue(
111 | allFailed(result),
112 | "Not all snippets were rejected when accessing a file."
113 | );
114 | }
115 |
116 | @Test
117 | void blocksNetworkIo() {
118 | JShellResult result = wrapper
119 | .eval("new java.net.URL(\"https://duckduckgo.com\").openConnection().connect()")
120 | .get(0);
121 |
122 | if (!allFailed(result)) {
123 | printSnippetResult(result);
124 | }
125 |
126 | assertTrue(
127 | allFailed(result),
128 | "Not all snippets were rejected when doing network I/O."
129 | );
130 | }
131 |
132 | @Test
133 | void blocksResettingSecurityManager() {
134 | JShellResult result = wrapper
135 | .eval("System.setSecurityManager(null)")
136 | .get(0);
137 |
138 | if (!allFailed(result)) {
139 | printSnippetResult(result);
140 | }
141 |
142 | assertTrue(
143 | allFailed(result),
144 | "Not all snippets were rejected when resetting the security manager."
145 | );
146 | }
147 |
148 | @Test()
149 | void doesNotEnterInfiniteLoopWhenRunningInvalidMethod() {
150 | JShellResult result = wrapper
151 | .eval("void beBad() {\n"
152 | + "try {\n"
153 | + "throw null;\n"
154 | + "catch (Throwable e) {\n"
155 | + " e.printStackTrace()\n"
156 | + "}\n"
157 | + "}\n")
158 | .get(0);
159 |
160 | if (!allFailed(result)) {
161 | printSnippetResult(result);
162 | }
163 |
164 | assertTrue(
165 | allFailed(result),
166 | "Not all snippets were rejected when checking for a timeout."
167 | );
168 | }
169 |
170 | private boolean allFailed(JShellResult result) {
171 | return result.getEvents().stream()
172 | .allMatch(snippetEvent ->
173 | snippetEvent.status() == Status.REJECTED
174 | || snippetEvent.exception() != null
175 | );
176 | }
177 |
178 | private void printSnippetResult(JShellResult result) {
179 | for (SnippetEvent event : result.getEvents()) {
180 | System.out.println(event);
181 | }
182 | }
183 | }
--------------------------------------------------------------------------------
/src/test/java/org/togetherjava/discord/server/io/StringOutputStreamTest.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.io;
2 |
3 | import org.apache.commons.lang3.StringUtils;
4 | import org.junit.jupiter.api.Test;
5 |
6 | import java.io.IOException;
7 | import java.nio.charset.StandardCharsets;
8 |
9 | import static org.junit.jupiter.api.Assertions.assertEquals;
10 |
11 | class StringOutputStreamTest {
12 |
13 | @Test
14 | void capturesOutput() throws IOException {
15 | String test = "hello world";
16 | checkToString(test, test, test.length(), Integer.MAX_VALUE);
17 | }
18 |
19 | @Test
20 | void truncates() throws IOException {
21 | checkToString("hello", "hell", 5, 4);
22 | }
23 |
24 | @Test
25 | void survivesBufferExpansion() throws IOException {
26 | final int length = 10_000;
27 | String test = StringUtils.repeat("A", length);
28 |
29 | checkToString(test, test, length, Integer.MAX_VALUE);
30 | }
31 |
32 | @Test
33 | void survivesBufferExpansionAndTruncates() throws IOException {
34 | final int length = 10_000;
35 | String test = StringUtils.repeat("A", length);
36 | String expected = StringUtils.repeat("A", 4000);
37 |
38 | checkToString(test, expected, length, 4_000);
39 | }
40 |
41 | private void checkToString(String input, String expected, int byteCount, int maxSize) throws IOException {
42 | StringOutputStream stringOutputStream = new StringOutputStream(maxSize);
43 |
44 | byte[] bytes = input.getBytes(StandardCharsets.US_ASCII);
45 |
46 | assertEquals(byteCount, bytes.length, "Somehow ASCII has changed?");
47 |
48 | stringOutputStream.write(bytes);
49 |
50 | assertEquals(expected, stringOutputStream.toString(), "Stored output differed.");
51 | }
52 | }
--------------------------------------------------------------------------------
/src/test/java/org/togetherjava/discord/server/rendering/TruncationRendererTest.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.rendering;
2 |
3 | import static org.junit.jupiter.api.Assertions.assertEquals;
4 |
5 | import net.dv8tion.jda.api.entities.MessageEmbed;
6 | import org.apache.commons.lang3.StringUtils;
7 | import org.junit.jupiter.api.Test;
8 |
9 | class TruncationRendererTest {
10 |
11 | @Test
12 | void truncatesNewLines() {
13 | String string = StringUtils.repeat("\n", 40) + "some stuff";
14 |
15 | String rendered = RenderUtils.truncateAndSanitize(string, MessageEmbed.VALUE_MAX_LENGTH);
16 |
17 | int newLines = StringUtils.countMatches(rendered, '\n');
18 | assertEquals(RenderUtils.NEWLINE_MAXIMUM, newLines, "Expected 10 newlines");
19 | }
20 |
21 | @Test
22 | void keepsNewlines() {
23 | String string = StringUtils.repeat("\n", RenderUtils.NEWLINE_MAXIMUM) + "some stuff";
24 |
25 | String rendered = RenderUtils.truncateAndSanitize(string, MessageEmbed.VALUE_MAX_LENGTH);
26 |
27 | int newLines = StringUtils.countMatches(rendered, '\n');
28 | assertEquals(RenderUtils.NEWLINE_MAXIMUM, newLines, "Expected 10 newlines.");
29 | }
30 | }
--------------------------------------------------------------------------------
/src/test/java/org/togetherjava/discord/server/sandbox/FilteredExecutionControlTest.java:
--------------------------------------------------------------------------------
1 | package org.togetherjava.discord.server.sandbox;
2 |
3 | import static org.junit.jupiter.api.Assertions.assertFalse;
4 | import static org.junit.jupiter.api.Assertions.assertTrue;
5 |
6 | import java.util.Collection;
7 | import java.util.List;
8 | import java.util.Map;
9 | import jdk.jshell.JShell;
10 | import jdk.jshell.SnippetEvent;
11 | import org.junit.jupiter.api.Test;
12 |
13 | class FilteredExecutionControlTest {
14 |
15 | @Test
16 | void testBlockPackage() {
17 | JShell jshell = getJShell(List.of("java.time"));
18 |
19 | assertTrue(
20 | failed(jshell, "java.time.LocalDateTime.now()"),
21 | "Was able to access a method in the blocked package."
22 | );
23 | assertTrue(
24 | failed(jshell, "java.time.format.DateTimeFormatter.ISO_DATE"),
25 | "Was able to access a field in a subpackage of a blocked package."
26 | );
27 | assertTrue(
28 | failed(jshell, "java.time.format.DateTimeFormatter.ISO_DATE.getChronology()"),
29 | "Was able to access a method in a subpackage of a blocked package."
30 | );
31 | assertFalse(
32 | failed(jshell, "Math.PI"),
33 | "Was not able to access a field in another package."
34 | );
35 | assertFalse(
36 | failed(jshell, "Math.abs(5)"),
37 | "Was not able to access a field in another package."
38 | );
39 | }
40 |
41 | @Test
42 | void testBlockClass() {
43 | JShell jshell = getJShell(List.of("java.time.LocalDate"));
44 |
45 | assertTrue(
46 | failed(jshell, "java.time.LocalDate.now()"),
47 | "Was able to access a method in the blocked class."
48 | );
49 | assertFalse(
50 | failed(jshell, "java.time.LocalDateTime.now()"),
51 | "Was not able to access another class with the same prefix."
52 | );
53 | assertFalse(
54 | failed(jshell, "java.time.format.DateTimeFormatter.ISO_DATE"),
55 | "Was not able to access a field in a not blocked class."
56 | );
57 | assertFalse(
58 | failed(jshell, "java.time.format.DateTimeFormatter.ISO_DATE.getChronology()"),
59 | "Was not able to access a method in a not blocked class."
60 | );
61 | }
62 |
63 | @Test
64 | void testBlockMethod() {
65 | JShell jshell = getJShell(List.of("java.time.LocalDate#now"));
66 |
67 | assertTrue(
68 | failed(jshell, "java.time.LocalDate.now()"),
69 | "Was able to access a blocked method."
70 | );
71 | assertFalse(
72 | failed(jshell, "java.time.LocalDateTime.now()"),
73 | "Was not able to access a method with the same name."
74 | );
75 | assertFalse(
76 | failed(jshell, "Math.abs(5)"),
77 | "Was not able to access a method in a not blocked class."
78 | );
79 | assertFalse(
80 | failed(jshell, "java.time.format.DateTimeFormatter.ISO_DATE.getChronology()"),
81 | "Was not able to access a method in a not blocked class."
82 | );
83 | }
84 |
85 | private JShell getJShell(Collection blockedPackages) {
86 | WhiteBlackList blackList = new WhiteBlackList();
87 | blockedPackages.forEach(blackList::blacklist);
88 |
89 | return JShell.builder()
90 | .executionEngine(
91 | new FilteredExecutionControlProvider(blackList),
92 | Map.of()
93 | )
94 | .build();
95 | }
96 |
97 | private boolean failed(JShell jshell, String command) {
98 | try {
99 | for (SnippetEvent event : jshell.eval(command)) {
100 | System.out.println("Got a value: " + event.value());
101 | }
102 | } catch (UnsupportedOperationException e) {
103 | return true;
104 | }
105 | return false;
106 | }
107 | }
--------------------------------------------------------------------------------