Web 编写Java代码后,在Java环境中运行

内容纲要

业务描述

  • 在低代码平台开发过程中,经常需要前端页面提供代码编辑器编辑代码,编辑完成之后后端运行代码; 如下图所示:

实现流程

  1. 前端编写的代码转化为.java文件
  2. 将.java文件编译成功.class文件
  3. 使用类加载器加载.class文件
  4. 获取该类的方法
  5. 执行方法

编写难点

类的卸载

  • 在Java中同一个类加载器只能加载一个类,当前端代码进行debug时, 需要反复的生成同一个类,所以需要对类进行卸载

    解决方案

    参考Tomcat中jsp的动态加载模式, jsp最终的呈现是Java将jsp编写成为一个类之后,再向前端输出对应的内容, 当我们在jsp中修改了内容后是可以不重启服务看到修改的内容

  • 在Tomcat的动态加载jsp中是每一个jsp单独由一个新的类加载器去加载,当类加载器不同,则可以在不同的类加载其中加载相同类名,内容不同的类,实现类的动态加载

  • 本文实现, 每次需要运行一个类的时候,先创建一个对应的类加载器,单独加载这一次代码, 当方法运行完整之后将类加载器置为null, 当类加载器被回收的时候, 加载的类也同步回收

并发控制

  • 并发控制采用的为分段锁的方式, 使用一个Map<String, ReentrantLock>(key: 类名; value : 可重入锁)来实现基于类名的分段锁方式, 保证对一个类的操作是只能为单线程

线上运行与缓存

  • 使用 Map<String, CacheWrapper> classCache = new ConcurrentHashMap<>()来缓存加载好的类信息和对应该类的类加载器, 当系统启动后,如果当前缓存中不存在该类则去加载, 如果存在直接使用, 提高了运行效率, 并且为直接用Java反射调用,与本地方法编写效率几乎相同,代码如下
            Class<?> loadClass = null;
            //从缓存中获取
            CacheWrapper wrapper = classCache.get(className);
            //缓存中不存在
            if (wrapper == null) {
                //加载
                wrapper = getLoadClass(code, className);
            }
            //获取到类信息
            loadClass = wrapper.getClazz();

Java内部执行其他代码安全控制

  • Java中提供沙箱模式, 开启沙箱模式之后可以控制用户编写代码的各种权限问题, 只需要创建一个类去继承SecurityManager类,然后在创建的类中进行全新啊控制即可, 本文的类为SandboxSecurityManager
  • 具体使用,只需要将创建的类使用System.setSecurityManager(securityManager);开启沙箱模式, 运行用户代码结束之后使用 System.setSecurityManager(null);清空安全检查即可

代码展示

github 代码地址
gitee 代码地址

  • JavaCodeEngine.java java代码执行类
package cn.fateverse.common.code.engine;

import cn.fateverse.common.code.console.ConsoleCapture;
import cn.fateverse.common.code.exception.SandboxClassNotFoundException;
import cn.fateverse.common.code.lock.SegmentLock;
import cn.fateverse.common.code.model.EngineResult;
import cn.fateverse.common.code.sandbox.SandboxClassLoader;
import cn.fateverse.common.code.sandbox.SandboxSecurityManager;
import cn.fateverse.common.core.exception.CustomException;
import lombok.Getter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.ObjectUtils;

import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;
import java.io.*;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @author Clay
 * @date 2023-10-24
 */
@Slf4j
public class JavaCodeEngine {

    private final String JAVA_SUFFIX = ".java";

    private final String CLASS_SUFFIX = ".class";

    private final String CLASS_PATH;

    private final URL url;

    private final Map<String, CacheWrapper> classCache = new ConcurrentHashMap<>();

    private final SandboxSecurityManager securityManager = new SandboxSecurityManager(classCache);

    // 获取Java编译器
    private final JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();

    public JavaCodeEngine(String classPath) {
        try {
            CLASS_PATH = classPath;
            File file = new File(CLASS_PATH);
            if (!file.exists()) {
                file.mkdirs();
            }
            url = file.toURI().toURL();
        } catch (MalformedURLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * 用于在开发环境中执行代码的私有方法。
     *
     * @param code       需要执行的代码字符串
     * @param className  类名
     * @param methodName 方法名
     * @param args       参数数组
     * @return 执行结构
     */
    @SneakyThrows
    public EngineResult mockExecute(String code, String className, String methodName, Object[] args) {
        Class<?> loadClass = null;
        try {
            // 加锁,确保类只加载一次
            loadClass = SegmentLock.lock(className, () -> {
                URLClassLoader tempClassLoader = null;
                try {
                    // 创建一个URLClassLoader,用于加载代码字符串
                    tempClassLoader = new URLClassLoader(new URL[]{url});
                    // 编译代码字符串为类
                    return compilerClass(className, code, tempClassLoader);
                } catch (Exception e) {
                    e.printStackTrace();
                    if (e instanceof CustomException) {
                        throw (CustomException) e;
                    }
                    // 异常处理,并抛出自定义的SandboxClassNotFoundException异常
                    throw new SandboxClassNotFoundException(e.getMessage());
                } finally {
                    if (tempClassLoader != null) {
                        tempClassLoader = null;
                    }
                }
            });
            // 获取需要执行的方法
            Method method = getMethod(methodName, loadClass);
            // 设置安全检查器
            System.setSecurityManager(securityManager);
            // 执行方法并返回结果
            return ConsoleCapture.capture(() -> method.invoke(null, args));
        } catch (CustomException e) {
            EngineResult result = new EngineResult();
            result.setSuccess(Boolean.FALSE);
            result.setConsole(e.getMessage());
            return result;
        } finally {
            // 清空安全检查器
            System.setSecurityManager(null);
            if (loadClass != null) {
                loadClass = null;
            }
            // 删除生成的java文件
            File javaFile = new File(CLASS_PATH + className + JAVA_SUFFIX);
            if (javaFile.exists()) {
                javaFile.delete();
            }
            // 删除生成的class文件
            File classFile = new File(CLASS_PATH + className + CLASS_SUFFIX);
            if (classFile.exists()) {
                classFile.delete();
            }
            // 执行垃圾回收
            System.gc();
        }
    }

    /**
     * 线上环境执行
     *
     * @param code       需要执行的代码字符串
     * @param className  类名
     * @param methodName 方法名
     * @param args       参数数组
     * @return 执行结构
     */
    public Object execute(String code, String className, String methodName, Object[] args) {
        try {
            Class<?> loadClass = null;
            //从缓存中获取
            CacheWrapper wrapper = classCache.get(className);
            //缓存中不存在
            if (wrapper == null) {
                //加载
                wrapper = getLoadClass(code, className);
            }
            //获取到类信息
            loadClass = wrapper.getClazz();
            //获取方法
            Method method = getMethod(methodName, loadClass);
            //开启安全模式
            System.setSecurityManager(securityManager);
            //执行方法
            return method.invoke(null, args);
        } catch (Exception e) {
            remove(className);
            e.printStackTrace();
        } finally {
            System.setSecurityManager(null);
        }
        return null;
    }

    /**
     * 获取到方法
     *
     * @param methodName 方法名称
     * @param loadClass  类信息
     * @return 方法对象
     */
    private Method getMethod(String methodName, Class<?> loadClass) {
        Method method = null;
        for (Method declaredMethod : loadClass.getDeclaredMethods()) {
            if (declaredMethod.getName().equals(methodName)) {
                method = declaredMethod;
            }
        }
        return method;
    }

    /**
     * 获取到编译完成的Class对象
     *
     * @param code      需要编译的代码
     * @param className 类名
     * @return 编译后的Java对象
     */
    private CacheWrapper getLoadClass(String code, String className) {
        //使用分段锁,提高效率,放多并发情况下多次对同一个类进行编译
        return SegmentLock.lock(className, () -> {
            try {
                URLClassLoader classLoader = new SandboxClassLoader(new URL[]{url});
                //执行编译
                Class<?> tempClass = compilerClass(className, code, classLoader);
                //创建缓存包装对象
                CacheWrapper wrapper = new CacheWrapper(tempClass, classLoader);
                //将编译之后的类对象放在缓存中,提高线上环境的运行效率
                classCache.put(className, wrapper);
                return wrapper;
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        });
    }

    /**
     * 编译Java代码
     *
     * @param className   类名
     * @param code        Java代码
     * @param classLoader 类加载器
     * @return 编译完成的类对象
     */
    private Class<?> compilerClass(String className, String code, URLClassLoader classLoader) {
        File tempFile = new File(CLASS_PATH + className + JAVA_SUFFIX);
        try (FileWriter writer = new FileWriter(tempFile)) {
            writer.write(code);
            writer.close();
            ByteArrayOutputStream errorStream = new ByteArrayOutputStream(10240);
            // 编译.java文件
            compiler.run(null, null, errorStream, tempFile.getPath());
            String trace = errorStream.toString();//存放控制台输出的字符串
            if (!ObjectUtils.isEmpty(trace)) {
                trace = trace.replace(CLASS_PATH + className + ".", "");
                throw new CustomException("编译错误: " + trace);
            }
            return classLoader.loadClass(className);
        } catch (Exception e) {
            e.printStackTrace();
            if (e instanceof CustomException) {
                throw (CustomException) e;
            }
            throw new CustomException("执行或者编辑错误!");
        }
    }

    /**
     * 删除类
     *
     * @param className 删除类
     * @return 删除结果
     */
    public Boolean remove(String className) {
        return SegmentLock.lock(className, () -> {
            CacheWrapper wrapper = classCache.get(className);
            if (wrapper != null) {
                classCache.remove(className);
                wrapper.remove();
            }
            //进行gc 垃圾挥手
            System.gc();
            //删除Java文件
            File javaFile = new File(CLASS_PATH + className + JAVA_SUFFIX);
            if (javaFile.exists()) {
                javaFile.delete();
            }
            //删除class文件
            File classFile = new File(CLASS_PATH + className + CLASS_SUFFIX);
            if (classFile.exists()) {
                classFile.delete();
            }
            return true;
        });
    }

    @Getter
    public static class CacheWrapper {

        private Class<?> clazz;

        private URLClassLoader classLoader;

        public CacheWrapper(Class<?> clazz, URLClassLoader classLoader) {
            this.clazz = clazz;
            this.classLoader = classLoader;
        }

        public void remove(){
            clazz = null;
            classLoader = null;
        }
    }
}
  • ConsoleCapture.java控制台日志捕获类
package cn.fateverse.common.code.console;

import cn.fateverse.common.code.model.EngineResult;
import cn.fateverse.common.core.exception.CustomException;

import java.io.ByteArrayOutputStream;
import java.io.PrintStream;

/**
 * 控制台输出捕获
 *
 * @author Clay
 * @date 2024/4/22 17:08
 */
public class ConsoleCapture {

    /**
     * 捕获方法
     *
     * @param task 任務
     * @return 返回结果
     */
    public static EngineResult capture(Task task) {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        PrintStream oldOut = System.out;
        System.setOut(new PrintStream(baos));
        Object result;
        String capturedOutput;
        try {
            result = task.execute();
        } catch (Exception e) {
            if (e instanceof CustomException) {
                throw (CustomException) e;
            }
            throw new RuntimeException(e);
        } finally {
            System.setOut(oldOut);
            // 从捕获的字节数组输出流中获取打印的文本
            capturedOutput = baos.toString();
        }
        return new EngineResult(result, capturedOutput);
    }
    public interface Task {
        Object execute() throws Exception;
    }
}
  • EngineResult.java捕获输出对象
package cn.fateverse.common.code.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * @author Clay
 * @date 2024/4/22 17:10
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class EngineResult {

    private Object result;

    private String console;

    private Boolean success;

    public EngineResult(Object result, String console) {
        success = true;
        this.result = result;
        this.console = console;
    }
}
  • SegmentLock.java分段锁对象
package cn.fateverse.common.code.lock;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Supplier;

/**
 * 分段锁对象
 *
 * @author Clay
 * @date 2023-10-25
 */
public class SegmentLock {

    private static final Map<String, ReentrantLock> lockMap = new ConcurrentHashMap<>();

    /**
     * 分段锁
     *
     * @param key      锁名称
     * @param supplier 需要执行的函数
     * @param <T>      接收泛型
     * @return 执行后的结果
     */
    public static <T> T lock(String key, Supplier<T> supplier) {
        ReentrantLock lock = lockMap.get(key);
        if (lock == null) {
            lock = lockMap.get(key);
            if (lock == null) {
                synchronized (lockMap) {
                    lock = new ReentrantLock();
                    lockMap.put(key, lock);
                }
            }
        }
        lock.lock();
        try {
            return supplier.get();
        } finally {
            lock.unlock();
        }
    }
}
  • SandboxSecurityManager.java 沙箱安全类
package cn.fateverse.common.code.sandbox;

import cn.fateverse.common.code.engine.JavaCodeEngine;

import java.io.FilePermission;
import java.lang.reflect.ReflectPermission;
import java.security.Permission;
import java.util.Map;
import java.util.PropertyPermission;
import java.util.Set;

public class SandboxSecurityManager extends SecurityManager {

    private final Map<String, JavaCodeEngine.CacheWrapper> classCache;

    public SandboxSecurityManager(Map<String, JavaCodeEngine.CacheWrapper> classCache) {
        this.classCache = classCache;
    }

    @Override
    public void checkPermission(Permission perm) {
        if (isSandboxCode(perm)) {
            if (!isAllowedPermission(perm)) {
                throw new SecurityException("Permission denied " + perm);
            }
        }
    }

    private boolean isSandboxCode(Permission perm) {
        Set<String> classKeySet = classCache.keySet();
        for (String key : classKeySet) {
            if (perm.getName().contains(key)) {
                return true;
            }
        }
        return false;
    }

    private boolean isAllowedPermission(Permission permission) {
        //权限:用于校验文件系统访问权限,包括读取、写入、删除文件,以及目录操作。权限名称可能包括文件路径和操作,如 "read", "write", "delete", "execute" 等。
        if (permission instanceof FilePermission) {
            System.out.println("触发文件读写权限");
            return false;
        }
        //权限:用于校验运行时权限,如程序启动、关闭虚拟机等。您可以根据名称进行控制,如 "exitVM"、"setSecurityManager" 等。
        if (permission instanceof RuntimePermission) {
            System.out.println("用于校验运行时权限");
            return false;
        }
        //权限:用于校验Java反射操作的权限,如 `suppressAccessChecks`、`newProxyInPackage` 等。
        if (permission instanceof ReflectPermission) {
            System.out.println("用于校验Java反射操作的权限");
            return false;
        }
        //权限:用于校验系统属性的权限,包括读取和设置系统属性。权限名称通常以属性名称和操作(如 "read" 或 "write")表示。
        if (permission instanceof PropertyPermission) {
            System.out.println("用于校验系统属性的权限");
            return false;
        }
        // 权限:用于校验数据库访问权限,包括连接数据库、执行SQL语句等。权限名称通常与数据库URL和操作相关。
//        if (permission instanceof SQLPermission) {
//            return false;
//
//        }
        //权限:用于校验网络套接字的权限,包括连接到特定主机和端口。权限名称通常以主机名和端口号的形式表示,如 "www.example.com:80".
//        if (permission instanceof SocketPermission) {
//            return false;
//
//        }
        //权限:用于校验安全管理器操作的权限,如 `createAccessControlContext`、`setPolicy` 等。
//        if (permission instanceof SecurityPermission) {
//            return false;
//
//        }
//        //序列化
//        if (permission instanceof SerializablePermission) {
//            return false;
//        }
//        System.out.println(permission);
        return true;
    }

}
THE END
分享
二维码
< <上一篇
下一篇>>