自己动手实现一个RPC框架(simple-rpc)

使用 Zookeeper + Netty + Spring + Protostuff 实现一个简单的 RPC 框架

概述

由于各服务部署在不同机器,服务间的调用免不了网络通信过程,服务消费方每调用一个服务都要写一坨网络通信相关的代码,不仅复杂而且极易出错。

如果有一种方式能让我们像调用本地服务一样调用远程服务,而让调用者对网络通信这些细节透明,那么将大大提高生产力,比如服务消费方在执行helloWorldService.sayHello(“test”) 时,实质上调用的是远端的服务。这种方式其实就是 RPC(Remote Procedure Call Protocol),在各大互联网公司中被广泛使用,如阿里巴巴的 hsf、dubbo(开源)、Facebook的thrift(开源)、Google grpc(开源)、Twitter 的 finagle(开源)等。

项目地址:https://github.com/hczs/simple-rpc

所用技术栈

  1. 首先我们需要发网络请求进行调用,这里使用网络框架 Netty
  2. 然后调用之间的消息需要序列化和反序列化,这里使用序列化框架 protostuff
  3. 再然后呢我们需要知道服务提供方地址是多少,也就是 服务发现/ 服务注册,我们使用 Zookeeper 来管理服务,使用 Curator 框架来操作 Zookeeper
  4. 我们使用 Spring 来方便的管理 Bean 进行随意的注入使用,以及配置文件值注入
  5. 使用 lombok 来精简代码,方便快速开发
  6. 使用 objenesis 库来优化我们反序列化 请求 / 响应对象的速度
  7. 使用 cglib 来优化我们接收响应处理执行方法的速度
  8. 使用 commons.lang3 库中的一些常用工具类

RPC 框架都帮我们干了些什么

要让网络通信细节对使用者透明,我们需要对通信细节进行封装,我们先看下一个RPC调用的流程涉及到哪些通信细节

  1. 服务消费方(client)调用以本地调用方式调用服务;
  2. client stub接收到调用后负责将方法、参数等组装成能够进行网络传输的消息体;
  3. client stub找到服务地址,并将消息发送到服务端;
  4. server stub收到消息后进行解码;
  5. server stub根据解码结果调用本地的服务;
  6. 本地服务执行并将结果返回给server stub;
  7. server stub将返回结果打包成消息并发送至消费方;
  8. client stub接收到消息,并进行解码;
  9. 服务消费方得到最终结果。

RPC的目标就是要2~8这些步骤都封装起来,让用户对这些细节透明。

简单来说:服务端,启动就自动注册服务,等待客户端调用

如何将封装上述2~8的步骤

可以通过动态代理,生成代理对象,然后再 invoke 方法的时候进行网络请求,达到对细节的封装,即下面的 invoke 方法:

/**
 * RPC 客户端远程调用具体处理器
 * @author: houcheng
 * @date: 2022/6/1 15:56:36
 */
@Slf4j
public class RemoteInvocationHandler implements InvocationHandler {

    /**
     * 服务发现对象
     */
    private final ServiceDiscovery serviceDiscovery;

    public RemoteInvocationHandler(ServiceDiscovery serviceDiscovery) {
        this.serviceDiscovery = serviceDiscovery;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        RpcRequest rpcRequest = new RpcRequest();
        rpcRequest.setRequestId(UUID.randomUUID().toString());
        rpcRequest.setInterfaceName(method.getDeclaringClass().getName());
        rpcRequest.setMethodName(method.getName());
        rpcRequest.setParameterTypes(method.getParameterTypes());
        rpcRequest.setParameters(args);
        // 服务发现 找一个具体能够处理请求的服务地址
        String serviceAddress = serviceDiscovery.discover();
        if (serviceAddress == null) {
            return null;
        }
        String[] hostAndPort = serviceAddress.split(":");
        String host = hostAndPort[0];
        int port = Integer.parseInt(hostAndPort[1]);
        // 发送请求
        RpcClient rpcClient = new RpcClient(host, port);
        RpcResponse rpcResponse = rpcClient.sendRequest(rpcRequest);
        log.info("rpcResponse: {}", rpcResponse);
        return rpcResponse.getResult();
    }
}

请求响应的消息数据结构设计

请求消息

  1. 接口名称:传过去,服务端知道你想要调用哪个接口
  2. 方法名:接口可能有多个方法,这个也是必须滴
  3. 参数类型 & 参数值:参数类型有很多,比如有bool、int、long、double、string、map、list,甚至如struct(class);以及相应的参数值;
  4. 超时时间:不能一直请求阻塞着(这里暂未实现)
  5. requestID:标识唯一请求id,这样请求和响应才能对得上号,要不然发出去多个请求,返回来多个响应,都分不清了
/**
 * RPC 请求对象
 * @author: houcheng
 * @date: 2022/5/31 16:07:06
 */
@Data
public class RpcRequest {

    /**
     * 请求id
     */
    private String requestId;

    /**
     * 接口名
     */
    private String interfaceName;

    /**
     * 方法名
     */
    private String methodName;

    /**
     * 参数类型
     */
    private Class<?>[] parameterTypes;

    /**
     * 参数值
     */
    private Object[] parameters;
}

响应消息

  1. 返回值
  2. 状态码
  3. requestID
  4. 异常信息
/**
 * RPC 响应对象
 * @author: houcheng
 * @date: 2022/5/31 16:09:27
 */
@Data
public class RpcResponse {

    /**
     * 请求ID
     */
    private String requestId;

    /**
     * 响应结果
     */
    private Object result;

    /**
     * 状态码
     */
    private Integer code;

    /**
     * 错误信息
     */
    private Throwable error;
}

消息编解码 / 序列化

为什么要序列化?因为序列化后会方便进行网络之间传输数据

序列化(编码):对象转换为二进制数据

反序列化(解码):二进制数据转换为对象

序列化

/**
 * 序列化工具类 使用 Protostuff 实现
 * @author :hc
 * @date :Created in 2022/5/31 19:48
 * @modified :
 */
public class SerializationUtil {

    private static Map<Class<?>, Schema<?>> cachedSchema = new ConcurrentHashMap<>();

    private static Objenesis objenesis = new ObjenesisStd(true);

    private SerializationUtil() {

    }

    @SuppressWarnings("unchecked")
    private static <T> Schema<T> getSchema(Class<T> cls) {
        return (Schema<T>) cachedSchema.computeIfAbsent(cls, key -> RuntimeSchema.createFrom(cls));
    }

    @SuppressWarnings("unchecked")
    public static <T> byte[] serialize(T obj) {
        Class<T> cls = (Class<T>) obj.getClass();
        LinkedBuffer buffer = LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE);
        try {
            Schema<T> schema = getSchema(cls);
            return ProtobufIOUtil.toByteArray(obj, schema, buffer);
        } catch (Exception e) {
            throw new IllegalStateException(e.getMessage(), e);
        } finally {
            buffer.clear();
        }
    }

    public static <T> T deserialize(byte[] data, Class<T> cls) {
        try {
            T message = objenesis.newInstance(cls);
            Schema<T> schema = getSchema(cls);
            ProtobufIOUtil.mergeFrom(data, message, schema);
            return message;
        } catch (Exception e) {
            throw new IllegalStateException(e.getMessage(), e);
        }
    }
}

编解码处理器

/**
 * @author :hc
 * @date :Created in 2022/5/31 19:40
 * @modified :
 */
public class RpcDecoder extends ByteToMessageDecoder {

    private Class<?> genericClass;

    public RpcDecoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        if (in.readableBytes() < 4 ) {
            return;
        }

        in.markReaderIndex();
        int dataLength = in.readInt();

        if (dataLength < 0) {
            ctx.close();
        }

        if (in.readableBytes() < dataLength) {
            in.resetReaderIndex();
            return;
        }

        byte[] data = new byte[dataLength];
        in.readBytes(data);

        Object obj = SerializationUtil.deserialize(data, genericClass);
        out.add(obj);
    }
}
/**
 * @author :hc
 * @date :Created in 2022/5/31 19:40
 * @modified :
 */
public class RpcEncoder extends MessageToByteEncoder {

    private Class<?> genericClass;

    public RpcEncoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    @Override
    public void encode(ChannelHandlerContext ctx, Object in, ByteBuf out) {
        if (genericClass.isInstance(in)) {
            byte[] data = SerializationUtil.serialize(in);
            out.writeInt(data.length);
            out.writeBytes(data);
        }
    }
}

通信

现在消息有了,该怎么发出去呢?就涉及到网络通信了,这里使用的是网络框架 Netty

服务端启动 Netty Server

/**
 * RPC Server / Netty Server
 * @author: houcheng
 * @date: 2022/5/31 16:57:13
 */
@Slf4j
@Component
@PropertySource("classpath:simple-rpc.properties")
public class RpcServer {

    /**
     * 注册中心地址
     */
    @Value("${registry.address:}")
    private String registryAddress;

    /**
     * RPC 服务启动地址
     */
    @Value("${service.address:}")
    private String serviceAddress;

    /**
     * 存放接口名和服务对象(实现类对象)之间的映射关系
     */
    private final Map<String, Object> handlerMap = new HashMap<>();

    /**
     * RpcServer 启动
     * @param applicationContext ApplicationContext
     */
    public void start(ApplicationContext applicationContext) {
        log.info("注册中心地址:{}", registryAddress);
        log.info("RPC 服务启动地址:{}", serviceAddress);
        if (StringUtils.isBlank(registryAddress) || StringUtils.isBlank(serviceAddress)) {
            log.warn("RPC 服务启动失败,请检查是否配置了 registry.address 和 service.address");
            return;
        }
        scanRpcServiceBean(applicationContext);
        startRpcServer();
    }

    /**
     * 获取 Spring 中所有带 RpcService 注解的 Bean
     * @param applicationContext ApplicationContext
     */
    private void scanRpcServiceBean(ApplicationContext applicationContext) {
        // 扫描所有带 RpcService 注解的 Bean
        Map<String, Object> serviceBeanMap = applicationContext.getBeansWithAnnotation(RpcService.class);
        if (serviceBeanMap.isEmpty()) {
            log.warn("No service bean found");
        }
        for (Object serviceBean : serviceBeanMap.values()) {
            // 获取 RpcService 注解的 value 值
            String interfaceName = serviceBean.getClass().getAnnotation(RpcService.class).value().getName();
            handlerMap.put(interfaceName, serviceBean);
        }
    }

    private void startRpcServer() {
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap serverBootstrap = new ServerBootstrap();
            serverBootstrap.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            // 服务端 对请求解码 对响应编码
                            ch.pipeline().addLast(new RpcDecoder(RpcRequest.class))
                                    .addLast(new RpcEncoder(RpcResponse.class))
                                    // 服务端处理 RPC 请求的 Handler
                                    .addLast(new RpcServerHandler(handlerMap));
                        }
                    })
                    .option(ChannelOption.SO_BACKLOG, 128)
                    .childOption(ChannelOption.SO_KEEPALIVE, true);
            // 解析服务地址
            String[] array = serviceAddress.split(":");
            String host = array[0];
            int port = Integer.parseInt(array[1]);

            // 启动 RPC 服务
            ChannelFuture channelFuture = serverBootstrap.bind(host, port).sync();
            log.info("RPC 服务器启动成功,监听端口:{}", port);
            // 服务注册
            ServiceRegistry zookeeperServiceRegistry = new ServiceRegistry(registryAddress);
            zookeeperServiceRegistry.register(serviceAddress);
            channelFuture.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            log.error("RpcServer start error", e);
            Thread.currentThread().interrupt();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }

}

客户端调用

/**
 * RPC Server / Netty Server
 * @author: houcheng
 * @date: 2022/5/31 16:57:13
 */
@Slf4j
@Component
@PropertySource("classpath:simple-rpc.properties")
public class RpcServer {

    /**
     * 注册中心地址
     */
    @Value("${registry.address:}")
    private String registryAddress;

    /**
     * RPC 服务启动地址
     */
    @Value("${service.address:}")
    private String serviceAddress;

    /**
     * 存放接口名和服务对象(实现类对象)之间的映射关系
     */
    private final Map<String, Object> handlerMap = new HashMap<>();

    /**
     * RpcServer 启动
     * @param applicationContext ApplicationContext
     */
    public void start(ApplicationContext applicationContext) {
        log.info("注册中心地址:{}", registryAddress);
        log.info("RPC 服务启动地址:{}", serviceAddress);
        if (StringUtils.isBlank(registryAddress) || StringUtils.isBlank(serviceAddress)) {
            log.warn("RPC 服务启动失败,请检查是否配置了 registry.address 和 service.address");
            return;
        }
        scanRpcServiceBean(applicationContext);
        startRpcServer();
    }

    /**
     * 获取 Spring 中所有带 RpcService 注解的 Bean
     * @param applicationContext ApplicationContext
     */
    private void scanRpcServiceBean(ApplicationContext applicationContext) {
        // 扫描所有带 RpcService 注解的 Bean
        Map<String, Object> serviceBeanMap = applicationContext.getBeansWithAnnotation(RpcService.class);
        if (serviceBeanMap.isEmpty()) {
            log.warn("No service bean found");
        }
        for (Object serviceBean : serviceBeanMap.values()) {
            // 获取 RpcService 注解的 value 值
            String interfaceName = serviceBean.getClass().getAnnotation(RpcService.class).value().getName();
            handlerMap.put(interfaceName, serviceBean);
        }
    }

    private void startRpcServer() {
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap serverBootstrap = new ServerBootstrap();
            serverBootstrap.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            // 服务端 对请求解码 对响应编码
                            ch.pipeline().addLast(new RpcDecoder(RpcRequest.class))
                                    .addLast(new RpcEncoder(RpcResponse.class))
                                    // 服务端处理 RPC 请求的 Handler
                                    .addLast(new RpcServerHandler(handlerMap));
                        }
                    })
                    .option(ChannelOption.SO_BACKLOG, 128)
                    .childOption(ChannelOption.SO_KEEPALIVE, true);
            // 解析服务地址
            String[] array = serviceAddress.split(":");
            String host = array[0];
            int port = Integer.parseInt(array[1]);

            // 启动 RPC 服务
            ChannelFuture channelFuture = serverBootstrap.bind(host, port).sync();
            log.info("RPC 服务器启动成功,监听端口:{}", port);
            // 服务注册
            ServiceRegistry zookeeperServiceRegistry = new ServiceRegistry(registryAddress);
            zookeeperServiceRegistry.register(serviceAddress);
            channelFuture.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            log.error("RpcServer start error", e);
            Thread.currentThread().interrupt();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }

}

如何进行服务注册 / 服务发现

就是现在我们知道怎么发网络请求了,但是不知道发给谁,所以我们需要一个服务表里面存储所有服务信息;

这里我们使用 Zookeeper,可以做到服务上下线自动通知(节点监控),非常方便

服务注册

/**
 * Zookeeper Server Registry
 * @author: houcheng
 * @date: 2022/5/31 16:58:41
 */
@Slf4j
public class ServiceRegistry {

    private final CuratorFramework curatorZkClient;

    public ServiceRegistry(String zkAddress) {
        curatorZkClient = ZookeeperUtil.getCuratorZookeeperClient(zkAddress);
    }

    /**
     * 服务注册
     * @param serviceAddress 服务地址
     */
    public void register(String serviceAddress) {
        log.info("Registering service address {}", serviceAddress);
        byte[] data = serviceAddress.getBytes(StandardCharsets.UTF_8);
        String servicePath = ZookeeperConstant.ZK_SERVICE_PATH_PREFIX;
        try {
            String resultPath = curatorZkClient.create().creatingParentsIfNeeded()
                    .withMode(CreateMode.EPHEMERAL_SEQUENTIAL)
                    .withACL(ZooDefs.Ids.OPEN_ACL_UNSAFE)
                    .forPath(servicePath, data);
            log.info("Service {} registered at path: {}", serviceAddress, resultPath);
        } catch (Exception e) {
            log.error("Registering service address {} failed", serviceAddress, e);
        }
    }
}

服务发现

/**
 * Zookeeper Server Discovery
 * @author :hc
 * @date :Created in 2022/5/31 20:52
 * @modified :
 */
@Slf4j
@Component
@PropertySource("classpath:simple-rpc.properties")
public class ServiceDiscovery {

    /**
     * 服务注册中心地址
     */
    @Value("${registry.address}")
    private String registryAddress;

    /**
     * 服务列表
     */
    private volatile List<String> serviceList = new ArrayList<>();

    @PostConstruct
    public void init() {
        // 注册永久监听
        watchNode();
    }

    /**
     * 返回一个可用的服务地址
     * 持续阻塞直到有可用的服务地址
     * @return 服务地址 ip:port 字符串
     */
    public String discover() {
        int size = serviceList.size();
        String result;
        // 如果服务列表为空,则返回null
        if (size == 0) {
            log.warn("No available service, please check the service registry address or service status ");
            return null;
        }
        // 如果服务列表只有一个服务地址,则直接返回
        if (size == 1) {
            result = serviceList.get(0);
            log.info("Only one service available, return it directly: {}", result);
            return result;
        }
        // 如果服务列表大于 1 个,则随机返回一个服务地址
        result = serviceList.get(ThreadLocalRandom.current().nextInt(size));
        log.info("Randomly select a service: {}", result);
        return result;
    }

    /**
     * 注册永久监听
     */
    private void watchNode() {
        log.info("Start watching the service registry node: {}", ZookeeperConstant.ZK_REGISTRY_PATH);
        log.info("Service registry address: {}", registryAddress);
        CuratorFramework curatorZkClient = ZookeeperUtil.getCuratorZookeeperClient(registryAddress);
        PathChildrenCache pathChildrenCache = new PathChildrenCache(curatorZkClient, ZookeeperConstant.ZK_REGISTRY_PATH, true);
        try {
            // 同步初始化 初始化后即可获取到当前服务列表
            pathChildrenCache.start(PathChildrenCache.StartMode.BUILD_INITIAL_CACHE);
            // 初次加载 获取服务列表
            flushServiceList(pathChildrenCache);
            // 添加永久监听 监听节点变化 并及时刷新服务列表
            pathChildrenCache.getListenable().addListener((client, event) -> {
                log.info("Node change event: {}", event.getType());
                // 监听到子节点变化 刷新服务列表
                flushServiceList(pathChildrenCache);
            });
        } catch (Exception e) {
            log.error("An exception occurred while listening for the change of zookeeper node", e);
        }

    }

    /**
     * 刷新服务列表
     * @param pathChildrenCache PathChildrenCache
     */
    private void flushServiceList(PathChildrenCache pathChildrenCache) {
        List<ChildData> childDataList = pathChildrenCache.getCurrentData();
        ArrayList<String> curServiceList = new ArrayList<>();
        childDataList.forEach(childData -> curServiceList.add(new String(childData.getData())));
        serviceList = curServiceList;
    }
}

自己动手实现一个RPC框架(simple-rpc)
https://www.powercheng.fun/articles/d77fd4b2/
作者
powercheng
发布于
2022年6月25日
许可协议