自己动手实现一个RPC框架(simple-rpc)
使用 Zookeeper + Netty + Spring + Protostuff 实现一个简单的 RPC 框架
概述
由于各服务部署在不同机器,服务间的调用免不了网络通信过程,服务消费方每调用一个服务都要写一坨网络通信相关的代码,不仅复杂而且极易出错。
如果有一种方式能让我们像调用本地服务一样调用远程服务,而让调用者对网络通信这些细节透明,那么将大大提高生产力,比如服务消费方在执行helloWorldService.sayHello(“test”) 时,实质上调用的是远端的服务。这种方式其实就是 RPC(Remote Procedure Call Protocol),在各大互联网公司中被广泛使用,如阿里巴巴的 hsf、dubbo(开源)、Facebook的thrift(开源)、Google grpc(开源)、Twitter 的 finagle(开源)等。
项目地址:https://github.com/hczs/simple-rpc
所用技术栈
- 首先我们需要发网络请求进行调用,这里使用网络框架 Netty
- 然后调用之间的消息需要序列化和反序列化,这里使用序列化框架 protostuff
- 再然后呢我们需要知道服务提供方地址是多少,也就是 服务发现/ 服务注册,我们使用 Zookeeper 来管理服务,使用 Curator 框架来操作 Zookeeper
- 我们使用 Spring 来方便的管理 Bean 进行随意的注入使用,以及配置文件值注入
- 使用 lombok 来精简代码,方便快速开发
- 使用 objenesis 库来优化我们反序列化 请求 / 响应对象的速度
- 使用 cglib 来优化我们接收响应处理执行方法的速度
- 使用 commons.lang3 库中的一些常用工具类
RPC 框架都帮我们干了些什么
要让网络通信细节对使用者透明,我们需要对通信细节进行封装,我们先看下一个RPC调用的流程涉及到哪些通信细节
![image-20220625145426377](/articles/d77fd4b2/f23DcF8d93image-20220625145426377.png)
- 服务消费方(client)调用以本地调用方式调用服务;
- client stub接收到调用后负责将方法、参数等组装成能够进行网络传输的消息体;
- client stub找到服务地址,并将消息发送到服务端;
- server stub收到消息后进行解码;
- server stub根据解码结果调用本地的服务;
- 本地服务执行并将结果返回给server stub;
- server stub将返回结果打包成消息并发送至消费方;
- client stub接收到消息,并进行解码;
- 服务消费方得到最终结果。
RPC的目标就是要2~8这些步骤都封装起来,让用户对这些细节透明。
简单来说:服务端,启动就自动注册服务,等待客户端调用
如何将封装上述2~8的步骤
可以通过动态代理,生成代理对象,然后再 invoke 方法的时候进行网络请求,达到对细节的封装,即下面的 invoke 方法:
/**
* RPC 客户端远程调用具体处理器
* @author: houcheng
* @date: 2022/6/1 15:56:36
*/
@Slf4j
public class RemoteInvocationHandler implements InvocationHandler {
/**
* 服务发现对象
*/
private final ServiceDiscovery serviceDiscovery;
public RemoteInvocationHandler(ServiceDiscovery serviceDiscovery) {
this.serviceDiscovery = serviceDiscovery;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
RpcRequest rpcRequest = new RpcRequest();
rpcRequest.setRequestId(UUID.randomUUID().toString());
rpcRequest.setInterfaceName(method.getDeclaringClass().getName());
rpcRequest.setMethodName(method.getName());
rpcRequest.setParameterTypes(method.getParameterTypes());
rpcRequest.setParameters(args);
// 服务发现 找一个具体能够处理请求的服务地址
String serviceAddress = serviceDiscovery.discover();
if (serviceAddress == null) {
return null;
}
String[] hostAndPort = serviceAddress.split(":");
String host = hostAndPort[0];
int port = Integer.parseInt(hostAndPort[1]);
// 发送请求
RpcClient rpcClient = new RpcClient(host, port);
RpcResponse rpcResponse = rpcClient.sendRequest(rpcRequest);
log.info("rpcResponse: {}", rpcResponse);
return rpcResponse.getResult();
}
}
请求响应的消息数据结构设计
请求消息
- 接口名称:传过去,服务端知道你想要调用哪个接口
- 方法名:接口可能有多个方法,这个也是必须滴
- 参数类型 & 参数值:参数类型有很多,比如有bool、int、long、double、string、map、list,甚至如struct(class);以及相应的参数值;
- 超时时间:不能一直请求阻塞着(这里暂未实现)
- requestID:标识唯一请求id,这样请求和响应才能对得上号,要不然发出去多个请求,返回来多个响应,都分不清了
/**
* RPC 请求对象
* @author: houcheng
* @date: 2022/5/31 16:07:06
*/
@Data
public class RpcRequest {
/**
* 请求id
*/
private String requestId;
/**
* 接口名
*/
private String interfaceName;
/**
* 方法名
*/
private String methodName;
/**
* 参数类型
*/
private Class<?>[] parameterTypes;
/**
* 参数值
*/
private Object[] parameters;
}
响应消息
- 返回值
- 状态码
- requestID
- 异常信息
/**
* RPC 响应对象
* @author: houcheng
* @date: 2022/5/31 16:09:27
*/
@Data
public class RpcResponse {
/**
* 请求ID
*/
private String requestId;
/**
* 响应结果
*/
private Object result;
/**
* 状态码
*/
private Integer code;
/**
* 错误信息
*/
private Throwable error;
}
消息编解码 / 序列化
为什么要序列化?因为序列化后会方便进行网络之间传输数据
序列化(编码):对象转换为二进制数据
反序列化(解码):二进制数据转换为对象
序列化
/**
* 序列化工具类 使用 Protostuff 实现
* @author :hc
* @date :Created in 2022/5/31 19:48
* @modified :
*/
public class SerializationUtil {
private static Map<Class<?>, Schema<?>> cachedSchema = new ConcurrentHashMap<>();
private static Objenesis objenesis = new ObjenesisStd(true);
private SerializationUtil() {
}
@SuppressWarnings("unchecked")
private static <T> Schema<T> getSchema(Class<T> cls) {
return (Schema<T>) cachedSchema.computeIfAbsent(cls, key -> RuntimeSchema.createFrom(cls));
}
@SuppressWarnings("unchecked")
public static <T> byte[] serialize(T obj) {
Class<T> cls = (Class<T>) obj.getClass();
LinkedBuffer buffer = LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE);
try {
Schema<T> schema = getSchema(cls);
return ProtobufIOUtil.toByteArray(obj, schema, buffer);
} catch (Exception e) {
throw new IllegalStateException(e.getMessage(), e);
} finally {
buffer.clear();
}
}
public static <T> T deserialize(byte[] data, Class<T> cls) {
try {
T message = objenesis.newInstance(cls);
Schema<T> schema = getSchema(cls);
ProtobufIOUtil.mergeFrom(data, message, schema);
return message;
} catch (Exception e) {
throw new IllegalStateException(e.getMessage(), e);
}
}
}
编解码处理器
/**
* @author :hc
* @date :Created in 2022/5/31 19:40
* @modified :
*/
public class RpcDecoder extends ByteToMessageDecoder {
private Class<?> genericClass;
public RpcDecoder(Class<?> genericClass) {
this.genericClass = genericClass;
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (in.readableBytes() < 4 ) {
return;
}
in.markReaderIndex();
int dataLength = in.readInt();
if (dataLength < 0) {
ctx.close();
}
if (in.readableBytes() < dataLength) {
in.resetReaderIndex();
return;
}
byte[] data = new byte[dataLength];
in.readBytes(data);
Object obj = SerializationUtil.deserialize(data, genericClass);
out.add(obj);
}
}
/**
* @author :hc
* @date :Created in 2022/5/31 19:40
* @modified :
*/
public class RpcEncoder extends MessageToByteEncoder {
private Class<?> genericClass;
public RpcEncoder(Class<?> genericClass) {
this.genericClass = genericClass;
}
@Override
public void encode(ChannelHandlerContext ctx, Object in, ByteBuf out) {
if (genericClass.isInstance(in)) {
byte[] data = SerializationUtil.serialize(in);
out.writeInt(data.length);
out.writeBytes(data);
}
}
}
通信
现在消息有了,该怎么发出去呢?就涉及到网络通信了,这里使用的是网络框架 Netty
服务端启动 Netty Server
/**
* RPC Server / Netty Server
* @author: houcheng
* @date: 2022/5/31 16:57:13
*/
@Slf4j
@Component
@PropertySource("classpath:simple-rpc.properties")
public class RpcServer {
/**
* 注册中心地址
*/
@Value("${registry.address:}")
private String registryAddress;
/**
* RPC 服务启动地址
*/
@Value("${service.address:}")
private String serviceAddress;
/**
* 存放接口名和服务对象(实现类对象)之间的映射关系
*/
private final Map<String, Object> handlerMap = new HashMap<>();
/**
* RpcServer 启动
* @param applicationContext ApplicationContext
*/
public void start(ApplicationContext applicationContext) {
log.info("注册中心地址:{}", registryAddress);
log.info("RPC 服务启动地址:{}", serviceAddress);
if (StringUtils.isBlank(registryAddress) || StringUtils.isBlank(serviceAddress)) {
log.warn("RPC 服务启动失败,请检查是否配置了 registry.address 和 service.address");
return;
}
scanRpcServiceBean(applicationContext);
startRpcServer();
}
/**
* 获取 Spring 中所有带 RpcService 注解的 Bean
* @param applicationContext ApplicationContext
*/
private void scanRpcServiceBean(ApplicationContext applicationContext) {
// 扫描所有带 RpcService 注解的 Bean
Map<String, Object> serviceBeanMap = applicationContext.getBeansWithAnnotation(RpcService.class);
if (serviceBeanMap.isEmpty()) {
log.warn("No service bean found");
}
for (Object serviceBean : serviceBeanMap.values()) {
// 获取 RpcService 注解的 value 值
String interfaceName = serviceBean.getClass().getAnnotation(RpcService.class).value().getName();
handlerMap.put(interfaceName, serviceBean);
}
}
private void startRpcServer() {
EventLoopGroup bossGroup = new NioEventLoopGroup();
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
ServerBootstrap serverBootstrap = new ServerBootstrap();
serverBootstrap.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
// 服务端 对请求解码 对响应编码
ch.pipeline().addLast(new RpcDecoder(RpcRequest.class))
.addLast(new RpcEncoder(RpcResponse.class))
// 服务端处理 RPC 请求的 Handler
.addLast(new RpcServerHandler(handlerMap));
}
})
.option(ChannelOption.SO_BACKLOG, 128)
.childOption(ChannelOption.SO_KEEPALIVE, true);
// 解析服务地址
String[] array = serviceAddress.split(":");
String host = array[0];
int port = Integer.parseInt(array[1]);
// 启动 RPC 服务
ChannelFuture channelFuture = serverBootstrap.bind(host, port).sync();
log.info("RPC 服务器启动成功,监听端口:{}", port);
// 服务注册
ServiceRegistry zookeeperServiceRegistry = new ServiceRegistry(registryAddress);
zookeeperServiceRegistry.register(serviceAddress);
channelFuture.channel().closeFuture().sync();
} catch (InterruptedException e) {
log.error("RpcServer start error", e);
Thread.currentThread().interrupt();
} finally {
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}
}
}
客户端调用
/**
* RPC Server / Netty Server
* @author: houcheng
* @date: 2022/5/31 16:57:13
*/
@Slf4j
@Component
@PropertySource("classpath:simple-rpc.properties")
public class RpcServer {
/**
* 注册中心地址
*/
@Value("${registry.address:}")
private String registryAddress;
/**
* RPC 服务启动地址
*/
@Value("${service.address:}")
private String serviceAddress;
/**
* 存放接口名和服务对象(实现类对象)之间的映射关系
*/
private final Map<String, Object> handlerMap = new HashMap<>();
/**
* RpcServer 启动
* @param applicationContext ApplicationContext
*/
public void start(ApplicationContext applicationContext) {
log.info("注册中心地址:{}", registryAddress);
log.info("RPC 服务启动地址:{}", serviceAddress);
if (StringUtils.isBlank(registryAddress) || StringUtils.isBlank(serviceAddress)) {
log.warn("RPC 服务启动失败,请检查是否配置了 registry.address 和 service.address");
return;
}
scanRpcServiceBean(applicationContext);
startRpcServer();
}
/**
* 获取 Spring 中所有带 RpcService 注解的 Bean
* @param applicationContext ApplicationContext
*/
private void scanRpcServiceBean(ApplicationContext applicationContext) {
// 扫描所有带 RpcService 注解的 Bean
Map<String, Object> serviceBeanMap = applicationContext.getBeansWithAnnotation(RpcService.class);
if (serviceBeanMap.isEmpty()) {
log.warn("No service bean found");
}
for (Object serviceBean : serviceBeanMap.values()) {
// 获取 RpcService 注解的 value 值
String interfaceName = serviceBean.getClass().getAnnotation(RpcService.class).value().getName();
handlerMap.put(interfaceName, serviceBean);
}
}
private void startRpcServer() {
EventLoopGroup bossGroup = new NioEventLoopGroup();
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
ServerBootstrap serverBootstrap = new ServerBootstrap();
serverBootstrap.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
// 服务端 对请求解码 对响应编码
ch.pipeline().addLast(new RpcDecoder(RpcRequest.class))
.addLast(new RpcEncoder(RpcResponse.class))
// 服务端处理 RPC 请求的 Handler
.addLast(new RpcServerHandler(handlerMap));
}
})
.option(ChannelOption.SO_BACKLOG, 128)
.childOption(ChannelOption.SO_KEEPALIVE, true);
// 解析服务地址
String[] array = serviceAddress.split(":");
String host = array[0];
int port = Integer.parseInt(array[1]);
// 启动 RPC 服务
ChannelFuture channelFuture = serverBootstrap.bind(host, port).sync();
log.info("RPC 服务器启动成功,监听端口:{}", port);
// 服务注册
ServiceRegistry zookeeperServiceRegistry = new ServiceRegistry(registryAddress);
zookeeperServiceRegistry.register(serviceAddress);
channelFuture.channel().closeFuture().sync();
} catch (InterruptedException e) {
log.error("RpcServer start error", e);
Thread.currentThread().interrupt();
} finally {
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}
}
}
如何进行服务注册 / 服务发现
就是现在我们知道怎么发网络请求了,但是不知道发给谁,所以我们需要一个服务表里面存储所有服务信息;
这里我们使用 Zookeeper,可以做到服务上下线自动通知(节点监控),非常方便
服务注册
/**
* Zookeeper Server Registry
* @author: houcheng
* @date: 2022/5/31 16:58:41
*/
@Slf4j
public class ServiceRegistry {
private final CuratorFramework curatorZkClient;
public ServiceRegistry(String zkAddress) {
curatorZkClient = ZookeeperUtil.getCuratorZookeeperClient(zkAddress);
}
/**
* 服务注册
* @param serviceAddress 服务地址
*/
public void register(String serviceAddress) {
log.info("Registering service address {}", serviceAddress);
byte[] data = serviceAddress.getBytes(StandardCharsets.UTF_8);
String servicePath = ZookeeperConstant.ZK_SERVICE_PATH_PREFIX;
try {
String resultPath = curatorZkClient.create().creatingParentsIfNeeded()
.withMode(CreateMode.EPHEMERAL_SEQUENTIAL)
.withACL(ZooDefs.Ids.OPEN_ACL_UNSAFE)
.forPath(servicePath, data);
log.info("Service {} registered at path: {}", serviceAddress, resultPath);
} catch (Exception e) {
log.error("Registering service address {} failed", serviceAddress, e);
}
}
}
服务发现
/**
* Zookeeper Server Discovery
* @author :hc
* @date :Created in 2022/5/31 20:52
* @modified :
*/
@Slf4j
@Component
@PropertySource("classpath:simple-rpc.properties")
public class ServiceDiscovery {
/**
* 服务注册中心地址
*/
@Value("${registry.address}")
private String registryAddress;
/**
* 服务列表
*/
private volatile List<String> serviceList = new ArrayList<>();
@PostConstruct
public void init() {
// 注册永久监听
watchNode();
}
/**
* 返回一个可用的服务地址
* 持续阻塞直到有可用的服务地址
* @return 服务地址 ip:port 字符串
*/
public String discover() {
int size = serviceList.size();
String result;
// 如果服务列表为空,则返回null
if (size == 0) {
log.warn("No available service, please check the service registry address or service status ");
return null;
}
// 如果服务列表只有一个服务地址,则直接返回
if (size == 1) {
result = serviceList.get(0);
log.info("Only one service available, return it directly: {}", result);
return result;
}
// 如果服务列表大于 1 个,则随机返回一个服务地址
result = serviceList.get(ThreadLocalRandom.current().nextInt(size));
log.info("Randomly select a service: {}", result);
return result;
}
/**
* 注册永久监听
*/
private void watchNode() {
log.info("Start watching the service registry node: {}", ZookeeperConstant.ZK_REGISTRY_PATH);
log.info("Service registry address: {}", registryAddress);
CuratorFramework curatorZkClient = ZookeeperUtil.getCuratorZookeeperClient(registryAddress);
PathChildrenCache pathChildrenCache = new PathChildrenCache(curatorZkClient, ZookeeperConstant.ZK_REGISTRY_PATH, true);
try {
// 同步初始化 初始化后即可获取到当前服务列表
pathChildrenCache.start(PathChildrenCache.StartMode.BUILD_INITIAL_CACHE);
// 初次加载 获取服务列表
flushServiceList(pathChildrenCache);
// 添加永久监听 监听节点变化 并及时刷新服务列表
pathChildrenCache.getListenable().addListener((client, event) -> {
log.info("Node change event: {}", event.getType());
// 监听到子节点变化 刷新服务列表
flushServiceList(pathChildrenCache);
});
} catch (Exception e) {
log.error("An exception occurred while listening for the change of zookeeper node", e);
}
}
/**
* 刷新服务列表
* @param pathChildrenCache PathChildrenCache
*/
private void flushServiceList(PathChildrenCache pathChildrenCache) {
List<ChildData> childDataList = pathChildrenCache.getCurrentData();
ArrayList<String> curServiceList = new ArrayList<>();
childDataList.forEach(childData -> curServiceList.add(new String(childData.getData())));
serviceList = curServiceList;
}
}