自己动手实现一个RPC框架(simple-rpc)

使用 Zookeeper + Netty + Spring + Protostuff 实现一个简单的 RPC 框架

概述

由于各服务部署在不同机器,服务间的调用免不了网络通信过程,服务消费方每调用一个服务都要写一坨网络通信相关的代码,不仅复杂而且极易出错。

如果有一种方式能让我们像调用本地服务一样调用远程服务,而让调用者对网络通信这些细节透明,那么将大大提高生产力,比如服务消费方在执行helloWorldService.sayHello(“test”) 时,实质上调用的是远端的服务。这种方式其实就是 RPC(Remote Procedure Call Protocol),在各大互联网公司中被广泛使用,如阿里巴巴的 hsf、dubbo(开源)、Facebook的thrift(开源)、Google grpc(开源)、Twitter 的 finagle(开源)等。

项目地址:https://github.com/hczs/simple-rpc

所用技术栈

  1. 首先我们需要发网络请求进行调用,这里使用网络框架 Netty
  2. 然后调用之间的消息需要序列化和反序列化,这里使用序列化框架 protostuff
  3. 再然后呢我们需要知道服务提供方地址是多少,也就是 服务发现/ 服务注册,我们使用 Zookeeper 来管理服务,使用 Curator 框架来操作 Zookeeper
  4. 我们使用 Spring 来方便的管理 Bean 进行随意的注入使用,以及配置文件值注入
  5. 使用 lombok 来精简代码,方便快速开发
  6. 使用 objenesis 库来优化我们反序列化 请求 / 响应对象的速度
  7. 使用 cglib 来优化我们接收响应处理执行方法的速度
  8. 使用 commons.lang3 库中的一些常用工具类

RPC 框架都帮我们干了些什么

要让网络通信细节对使用者透明,我们需要对通信细节进行封装,我们先看下一个RPC调用的流程涉及到哪些通信细节

  1. 服务消费方(client)调用以本地调用方式调用服务;
  2. client stub接收到调用后负责将方法、参数等组装成能够进行网络传输的消息体;
  3. client stub找到服务地址,并将消息发送到服务端;
  4. server stub收到消息后进行解码;
  5. server stub根据解码结果调用本地的服务;
  6. 本地服务执行并将结果返回给server stub;
  7. server stub将返回结果打包成消息并发送至消费方;
  8. client stub接收到消息,并进行解码;
  9. 服务消费方得到最终结果。

RPC的目标就是要2~8这些步骤都封装起来,让用户对这些细节透明。

简单来说:服务端,启动就自动注册服务,等待客户端调用

如何将封装上述2~8的步骤

可以通过动态代理,生成代理对象,然后再 invoke 方法的时候进行网络请求,达到对细节的封装,即下面的 invoke 方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
/**
* RPC 客户端远程调用具体处理器
* @author: houcheng
* @date: 2022/6/1 15:56:36
*/
@Slf4j
public class RemoteInvocationHandler implements InvocationHandler {

/**
* 服务发现对象
*/
private final ServiceDiscovery serviceDiscovery;

public RemoteInvocationHandler(ServiceDiscovery serviceDiscovery) {
this.serviceDiscovery = serviceDiscovery;
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
RpcRequest rpcRequest = new RpcRequest();
rpcRequest.setRequestId(UUID.randomUUID().toString());
rpcRequest.setInterfaceName(method.getDeclaringClass().getName());
rpcRequest.setMethodName(method.getName());
rpcRequest.setParameterTypes(method.getParameterTypes());
rpcRequest.setParameters(args);
// 服务发现 找一个具体能够处理请求的服务地址
String serviceAddress = serviceDiscovery.discover();
if (serviceAddress == null) {
return null;
}
String[] hostAndPort = serviceAddress.split(":");
String host = hostAndPort[0];
int port = Integer.parseInt(hostAndPort[1]);
// 发送请求
RpcClient rpcClient = new RpcClient(host, port);
RpcResponse rpcResponse = rpcClient.sendRequest(rpcRequest);
log.info("rpcResponse: {}", rpcResponse);
return rpcResponse.getResult();
}
}

请求响应的消息数据结构设计

请求消息

  1. 接口名称:传过去,服务端知道你想要调用哪个接口
  2. 方法名:接口可能有多个方法,这个也是必须滴
  3. 参数类型 & 参数值:参数类型有很多,比如有bool、int、long、double、string、map、list,甚至如struct(class);以及相应的参数值;
  4. 超时时间:不能一直请求阻塞着(这里暂未实现)
  5. requestID:标识唯一请求id,这样请求和响应才能对得上号,要不然发出去多个请求,返回来多个响应,都分不清了
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/**
* RPC 请求对象
* @author: houcheng
* @date: 2022/5/31 16:07:06
*/
@Data
public class RpcRequest {

/**
* 请求id
*/
private String requestId;

/**
* 接口名
*/
private String interfaceName;

/**
* 方法名
*/
private String methodName;

/**
* 参数类型
*/
private Class<?>[] parameterTypes;

/**
* 参数值
*/
private Object[] parameters;
}

响应消息

  1. 返回值
  2. 状态码
  3. requestID
  4. 异常信息
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/**
* RPC 响应对象
* @author: houcheng
* @date: 2022/5/31 16:09:27
*/
@Data
public class RpcResponse {

/**
* 请求ID
*/
private String requestId;

/**
* 响应结果
*/
private Object result;

/**
* 状态码
*/
private Integer code;

/**
* 错误信息
*/
private Throwable error;
}

消息编解码 / 序列化

为什么要序列化?因为序列化后会方便进行网络之间传输数据

序列化(编码):对象转换为二进制数据

反序列化(解码):二进制数据转换为对象

序列化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
/**
* 序列化工具类 使用 Protostuff 实现
* @author :hc
* @date :Created in 2022/5/31 19:48
* @modified
*/
public class SerializationUtil {

private static Map<Class<?>, Schema<?>> cachedSchema = new ConcurrentHashMap<>();

private static Objenesis objenesis = new ObjenesisStd(true);

private SerializationUtil() {

}

@SuppressWarnings("unchecked")
private static <T> Schema<T> getSchema(Class<T> cls) {
return (Schema<T>) cachedSchema.computeIfAbsent(cls, key -> RuntimeSchema.createFrom(cls));
}

@SuppressWarnings("unchecked")
public static <T> byte[] serialize(T obj) {
Class<T> cls = (Class<T>) obj.getClass();
LinkedBuffer buffer = LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE);
try {
Schema<T> schema = getSchema(cls);
return ProtobufIOUtil.toByteArray(obj, schema, buffer);
} catch (Exception e) {
throw new IllegalStateException(e.getMessage(), e);
} finally {
buffer.clear();
}
}

public static <T> T deserialize(byte[] data, Class<T> cls) {
try {
T message = objenesis.newInstance(cls);
Schema<T> schema = getSchema(cls);
ProtobufIOUtil.mergeFrom(data, message, schema);
return message;
} catch (Exception e) {
throw new IllegalStateException(e.getMessage(), e);
}
}
}

编解码处理器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
/**
* @author :hc
* @date :Created in 2022/5/31 19:40
* @modified
*/
public class RpcDecoder extends ByteToMessageDecoder {

private Class<?> genericClass;

public RpcDecoder(Class<?> genericClass) {
this.genericClass = genericClass;
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (in.readableBytes() < 4 ) {
return;
}

in.markReaderIndex();
int dataLength = in.readInt();

if (dataLength < 0) {
ctx.close();
}

if (in.readableBytes() < dataLength) {
in.resetReaderIndex();
return;
}

byte[] data = new byte[dataLength];
in.readBytes(data);

Object obj = SerializationUtil.deserialize(data, genericClass);
out.add(obj);
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/**
* @author :hc
* @date :Created in 2022/5/31 19:40
* @modified
*/
public class RpcEncoder extends MessageToByteEncoder {

private Class<?> genericClass;

public RpcEncoder(Class<?> genericClass) {
this.genericClass = genericClass;
}

@Override
public void encode(ChannelHandlerContext ctx, Object in, ByteBuf out) {
if (genericClass.isInstance(in)) {
byte[] data = SerializationUtil.serialize(in);
out.writeInt(data.length);
out.writeBytes(data);
}
}
}

通信

现在消息有了,该怎么发出去呢?就涉及到网络通信了,这里使用的是网络框架 Netty

服务端启动 Netty Server

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/**
* RPC Server / Netty Server
* @author: houcheng
* @date: 2022/5/31 16:57:13
*/
@Slf4j
@Component
@PropertySource("classpath:simple-rpc.properties")
public class RpcServer {

/**
* 注册中心地址
*/
@Value("${registry.address:}")
private String registryAddress;

/**
* RPC 服务启动地址
*/
@Value("${service.address:}")
private String serviceAddress;

/**
* 存放接口名和服务对象(实现类对象)之间的映射关系
*/
private final Map<String, Object> handlerMap = new HashMap<>();

/**
* RpcServer 启动
* @param applicationContext ApplicationContext
*/
public void start(ApplicationContext applicationContext) {
log.info("注册中心地址:{}", registryAddress);
log.info("RPC 服务启动地址:{}", serviceAddress);
if (StringUtils.isBlank(registryAddress) || StringUtils.isBlank(serviceAddress)) {
log.warn("RPC 服务启动失败,请检查是否配置了 registry.address 和 service.address");
return;
}
scanRpcServiceBean(applicationContext);
startRpcServer();
}

/**
* 获取 Spring 中所有带 RpcService 注解的 Bean
* @param applicationContext ApplicationContext
*/
private void scanRpcServiceBean(ApplicationContext applicationContext) {
// 扫描所有带 RpcService 注解的 Bean
Map<String, Object> serviceBeanMap = applicationContext.getBeansWithAnnotation(RpcService.class);
if (serviceBeanMap.isEmpty()) {
log.warn("No service bean found");
}
for (Object serviceBean : serviceBeanMap.values()) {
// 获取 RpcService 注解的 value 值
String interfaceName = serviceBean.getClass().getAnnotation(RpcService.class).value().getName();
handlerMap.put(interfaceName, serviceBean);
}
}

private void startRpcServer() {
EventLoopGroup bossGroup = new NioEventLoopGroup();
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
ServerBootstrap serverBootstrap = new ServerBootstrap();
serverBootstrap.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
// 服务端 对请求解码 对响应编码
ch.pipeline().addLast(new RpcDecoder(RpcRequest.class))
.addLast(new RpcEncoder(RpcResponse.class))
// 服务端处理 RPC 请求的 Handler
.addLast(new RpcServerHandler(handlerMap));
}
})
.option(ChannelOption.SO_BACKLOG, 128)
.childOption(ChannelOption.SO_KEEPALIVE, true);
// 解析服务地址
String[] array = serviceAddress.split(":");
String host = array[0];
int port = Integer.parseInt(array[1]);

// 启动 RPC 服务
ChannelFuture channelFuture = serverBootstrap.bind(host, port).sync();
log.info("RPC 服务器启动成功,监听端口:{}", port);
// 服务注册
ServiceRegistry zookeeperServiceRegistry = new ServiceRegistry(registryAddress);
zookeeperServiceRegistry.register(serviceAddress);
channelFuture.channel().closeFuture().sync();
} catch (InterruptedException e) {
log.error("RpcServer start error", e);
Thread.currentThread().interrupt();
} finally {
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}
}

}

客户端调用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/**
* RPC Server / Netty Server
* @author: houcheng
* @date: 2022/5/31 16:57:13
*/
@Slf4j
@Component
@PropertySource("classpath:simple-rpc.properties")
public class RpcServer {

/**
* 注册中心地址
*/
@Value("${registry.address:}")
private String registryAddress;

/**
* RPC 服务启动地址
*/
@Value("${service.address:}")
private String serviceAddress;

/**
* 存放接口名和服务对象(实现类对象)之间的映射关系
*/
private final Map<String, Object> handlerMap = new HashMap<>();

/**
* RpcServer 启动
* @param applicationContext ApplicationContext
*/
public void start(ApplicationContext applicationContext) {
log.info("注册中心地址:{}", registryAddress);
log.info("RPC 服务启动地址:{}", serviceAddress);
if (StringUtils.isBlank(registryAddress) || StringUtils.isBlank(serviceAddress)) {
log.warn("RPC 服务启动失败,请检查是否配置了 registry.address 和 service.address");
return;
}
scanRpcServiceBean(applicationContext);
startRpcServer();
}

/**
* 获取 Spring 中所有带 RpcService 注解的 Bean
* @param applicationContext ApplicationContext
*/
private void scanRpcServiceBean(ApplicationContext applicationContext) {
// 扫描所有带 RpcService 注解的 Bean
Map<String, Object> serviceBeanMap = applicationContext.getBeansWithAnnotation(RpcService.class);
if (serviceBeanMap.isEmpty()) {
log.warn("No service bean found");
}
for (Object serviceBean : serviceBeanMap.values()) {
// 获取 RpcService 注解的 value 值
String interfaceName = serviceBean.getClass().getAnnotation(RpcService.class).value().getName();
handlerMap.put(interfaceName, serviceBean);
}
}

private void startRpcServer() {
EventLoopGroup bossGroup = new NioEventLoopGroup();
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
ServerBootstrap serverBootstrap = new ServerBootstrap();
serverBootstrap.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
// 服务端 对请求解码 对响应编码
ch.pipeline().addLast(new RpcDecoder(RpcRequest.class))
.addLast(new RpcEncoder(RpcResponse.class))
// 服务端处理 RPC 请求的 Handler
.addLast(new RpcServerHandler(handlerMap));
}
})
.option(ChannelOption.SO_BACKLOG, 128)
.childOption(ChannelOption.SO_KEEPALIVE, true);
// 解析服务地址
String[] array = serviceAddress.split(":");
String host = array[0];
int port = Integer.parseInt(array[1]);

// 启动 RPC 服务
ChannelFuture channelFuture = serverBootstrap.bind(host, port).sync();
log.info("RPC 服务器启动成功,监听端口:{}", port);
// 服务注册
ServiceRegistry zookeeperServiceRegistry = new ServiceRegistry(registryAddress);
zookeeperServiceRegistry.register(serviceAddress);
channelFuture.channel().closeFuture().sync();
} catch (InterruptedException e) {
log.error("RpcServer start error", e);
Thread.currentThread().interrupt();
} finally {
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}
}

}

如何进行服务注册 / 服务发现

就是现在我们知道怎么发网络请求了,但是不知道发给谁,所以我们需要一个服务表里面存储所有服务信息;

这里我们使用 Zookeeper,可以做到服务上下线自动通知(节点监控),非常方便

服务注册

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/**
* Zookeeper Server Registry
* @author: houcheng
* @date: 2022/5/31 16:58:41
*/
@Slf4j
public class ServiceRegistry {

private final CuratorFramework curatorZkClient;

public ServiceRegistry(String zkAddress) {
curatorZkClient = ZookeeperUtil.getCuratorZookeeperClient(zkAddress);
}

/**
* 服务注册
* @param serviceAddress 服务地址
*/
public void register(String serviceAddress) {
log.info("Registering service address {}", serviceAddress);
byte[] data = serviceAddress.getBytes(StandardCharsets.UTF_8);
String servicePath = ZookeeperConstant.ZK_SERVICE_PATH_PREFIX;
try {
String resultPath = curatorZkClient.create().creatingParentsIfNeeded()
.withMode(CreateMode.EPHEMERAL_SEQUENTIAL)
.withACL(ZooDefs.Ids.OPEN_ACL_UNSAFE)
.forPath(servicePath, data);
log.info("Service {} registered at path: {}", serviceAddress, resultPath);
} catch (Exception e) {
log.error("Registering service address {} failed", serviceAddress, e);
}
}
}

服务发现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
/**
* Zookeeper Server Discovery
* @author :hc
* @date :Created in 2022/5/31 20:52
* @modified
*/
@Slf4j
@Component
@PropertySource("classpath:simple-rpc.properties")
public class ServiceDiscovery {

/**
* 服务注册中心地址
*/
@Value("${registry.address}")
private String registryAddress;

/**
* 服务列表
*/
private volatile List<String> serviceList = new ArrayList<>();

@PostConstruct
public void init() {
// 注册永久监听
watchNode();
}

/**
* 返回一个可用的服务地址
* 持续阻塞直到有可用的服务地址
* @return 服务地址 ip:port 字符串
*/
public String discover() {
int size = serviceList.size();
String result;
// 如果服务列表为空,则返回null
if (size == 0) {
log.warn("No available service, please check the service registry address or service status ");
return null;
}
// 如果服务列表只有一个服务地址,则直接返回
if (size == 1) {
result = serviceList.get(0);
log.info("Only one service available, return it directly: {}", result);
return result;
}
// 如果服务列表大于 1 个,则随机返回一个服务地址
result = serviceList.get(ThreadLocalRandom.current().nextInt(size));
log.info("Randomly select a service: {}", result);
return result;
}

/**
* 注册永久监听
*/
private void watchNode() {
log.info("Start watching the service registry node: {}", ZookeeperConstant.ZK_REGISTRY_PATH);
log.info("Service registry address: {}", registryAddress);
CuratorFramework curatorZkClient = ZookeeperUtil.getCuratorZookeeperClient(registryAddress);
PathChildrenCache pathChildrenCache = new PathChildrenCache(curatorZkClient, ZookeeperConstant.ZK_REGISTRY_PATH, true);
try {
// 同步初始化 初始化后即可获取到当前服务列表
pathChildrenCache.start(PathChildrenCache.StartMode.BUILD_INITIAL_CACHE);
// 初次加载 获取服务列表
flushServiceList(pathChildrenCache);
// 添加永久监听 监听节点变化 并及时刷新服务列表
pathChildrenCache.getListenable().addListener((client, event) -> {
log.info("Node change event: {}", event.getType());
// 监听到子节点变化 刷新服务列表
flushServiceList(pathChildrenCache);
});
} catch (Exception e) {
log.error("An exception occurred while listening for the change of zookeeper node", e);
}

}

/**
* 刷新服务列表
* @param pathChildrenCache PathChildrenCache
*/
private void flushServiceList(PathChildrenCache pathChildrenCache) {
List<ChildData> childDataList = pathChildrenCache.getCurrentData();
ArrayList<String> curServiceList = new ArrayList<>();
childDataList.forEach(childData -> curServiceList.add(new String(childData.getData())));
serviceList = curServiceList;
}
}