Handwriting a Simple RPC Framework Based on Spring Boot (Part 3)
This part continues from the previous one. With service registration and invocation working on the server side, the next step is the client, which mainly covers load balancing, rate limiting, request sending, and service discovery. The features are implemented below in the order an RPC call passes through them.
A single request:
Before implementing the client, first consider what a single request needs to carry.
First, it needs the service name and the method name, along with the method's arguments and their parameter types. Without these, the server cannot make the corresponding reflective call.
Second, it should carry the attributes from @RpcConsumer so the server can locate the correct service.
Finally, it should carry a value unique to this request so the call can be traced.
These are the basic fields a request needs. The project, group, service name, and version are combined into one name that distinguishes one service from another:
/**
 * Generate the service name, used to distinguish different services.
 * The result can be split on "*" to recover its parts.
 */
public String fetchRpcServiceName() {
    return this.getProject() + "*" + this.getGroup() + "*" + this.getServiceName() + "*" + this.getVersion();
}
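Here is a minimal sketch of a request class that carries these fields. It assumes a plain serializable Java class. Only project, group, serviceName, and version correspond to getters used in the snippet above. The other field names and the UUID-based request ID are illustrative assumptions.

```java
import java.io.Serializable;
import java.util.Arrays;
import java.util.UUID;

// Hypothetical request class; only project/group/serviceName/version mirror the getters above.
class RpcRequest implements Serializable {
    private static final long serialVersionUID = 1L;

    private final String requestId;      // unique per call, used for tracing
    private final String serviceName;    // interface name
    private final String methodName;
    private final Object[] parameters;
    private final Class<?>[] paramTypes;
    private final String project;        // attributes taken from @RpcConsumer
    private final String group;
    private final String version;

    RpcRequest(String serviceName, String methodName, Object[] parameters,
               Class<?>[] paramTypes, String project, String group, String version) {
        this.requestId = UUID.randomUUID().toString();
        this.serviceName = serviceName;
        this.methodName = methodName;
        this.parameters = parameters;
        this.paramTypes = paramTypes;
        this.project = project;
        this.group = group;
        this.version = version;
    }

    String getRequestId() { return requestId; }
    String getServiceName() { return serviceName; }
    String getMethodName() { return methodName; }
    Object[] getParameters() { return parameters; }
    Class<?>[] getParamTypes() { return paramTypes; }
    String getProject() { return project; }
    String getGroup() { return group; }
    String getVersion() { return version; }

    // Same format as fetchRpcServiceName above: project*group*serviceName*version
    String fetchRpcServiceName() {
        return project + "*" + group + "*" + serviceName + "*" + version;
    }

    @Override
    public String toString() {
        return "RpcRequest{" + requestId + ", " + fetchRpcServiceName() + "#" + methodName
                + Arrays.toString(paramTypes) + "}";
    }
}
```

In Java, `*` is a regular-expression metacharacter. Splitting the name back into its parts therefore needs `split("\\*")`.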
Service Proxy
First, during Spring startup, a BeanPostProcessor checks every bean for fields annotated with @RpcConsumer. For each such field it creates a proxy of the field's type and injects it. Later calls to methods on those fields go through the proxy, which sends the remote request.
/**
 * Create proxies for consumer fields and inject them.
 *
 * @param bean     the bean instance that has just been initialized
 * @param beanName the name of the bean
 * @return the same bean, with its @RpcConsumer fields populated
 * @throws BeansException declared by BeanPostProcessor
 */
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
    Class<?> toBeProcessedBean = bean.getClass();
    Field[] declaredFields = toBeProcessedBean.getDeclaredFields();
    for (Field declaredField : declaredFields) {
        if (declaredField.isAnnotationPresent(RpcConsumer.class)) {
            RpcConsumer annotation = declaredField.getAnnotation(RpcConsumer.class);
            // build the rpc service config from the annotation attributes
            RpcServiceConfig serviceConfig = RpcServiceConfig.builder()
                    .project(annotation.project())
                    .version(annotation.version())
                    .group(annotation.group())
                    .build();
            // create the proxy for the field's interface type
            RpcServiceProxy proxy = new RpcServiceProxy(sendingServiceAdapter, serviceConfig);
            Object rpcProxy = proxy.getProxy(declaredField.getType());
            declaredField.setAccessible(true);
            try {
                LogUtil.info("create service proxy: {}", bean);
                declaredField.set(bean, rpcProxy);
            } catch (IllegalAccessException e) {
                e.printStackTrace();
            }
        }
    }
    return bean;
}
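The scan-and-inject step can be tried without a Spring context. This sketch is hypothetical. `DemoConsumer` stands in for @RpcConsumer, and it needs RUNTIME retention so reflection can see it. The injected proxy does not send a request. It just echoes the method name, the annotation's version, and the first argument.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.Proxy;

// Hypothetical stand-in for @RpcConsumer; RUNTIME retention makes it visible via reflection.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@interface DemoConsumer {
    String version() default "1.0";
}

interface GreetingService {
    String greet(String name);
}

class OrderController {
    @DemoConsumer(version = "2.0")
    private GreetingService greetingService;

    String handle(String user) {
        return greetingService.greet(user);
    }
}

class ConsumerInjector {
    // Mirrors postProcessAfterInitialization: scan declared fields, inject a proxy per annotated field.
    static <T> T inject(T bean) {
        for (Field field : bean.getClass().getDeclaredFields()) {
            DemoConsumer annotation = field.getAnnotation(DemoConsumer.class);
            if (annotation == null) {
                continue;
            }
            Object proxy = createProxy(field.getType(), annotation.version());
            field.setAccessible(true);
            try {
                field.set(bean, proxy);
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("cannot inject " + field.getName(), e);
            }
        }
        return bean;
    }

    // Placeholder proxy: a real client would build and send a request here instead.
    private static Object createProxy(Class<?> type, String version) {
        return Proxy.newProxyInstance(type.getClassLoader(), new Class<?>[] {type},
                (p, method, args) -> method.getName() + "@" + version + ":" + args[0]);
    }
}
```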
Next, the proxy's invoke method assembles a request from the invoked method and its arguments and sends it. It then waits on the returned Future for the response and passes the result back to the caller.
/**
 * Get the proxy object for the given interface.
 */
@SuppressWarnings("unchecked")
public <T> T getProxy(Class<T> clazz) {
    return (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class<?>[] {clazz}, this);
}
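Here is a minimal sketch of such an invoke method. It rests on three assumptions made for illustration:

- A hypothetical `RequestSender` whose `send` returns a `CompletableFuture` that completes with the response.
- A simplified `Call` class in place of the real request.
- A 5-second timeout, chosen only for the example.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// Simplified request; field names are illustrative only.
final class Call {
    final String requestId = UUID.randomUUID().toString();
    final String serviceName;
    final String methodName;
    final Object[] parameters;
    final Class<?>[] paramTypes;

    Call(String serviceName, String methodName, Object[] parameters, Class<?>[] paramTypes) {
        this.serviceName = serviceName;
        this.methodName = methodName;
        this.parameters = parameters;
        this.paramTypes = paramTypes;
    }
}

// Hypothetical transport abstraction standing in for the sending adapter.
interface RequestSender {
    CompletableFuture<Object> send(Call call);
}

interface EchoService {
    String echo(String text);
}

class ServiceProxySketch implements InvocationHandler {
    private final RequestSender sender;

    ServiceProxySketch(RequestSender sender) {
        this.sender = sender;
    }

    @SuppressWarnings("unchecked")
    <T> T getProxy(Class<T> clazz) {
        return (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class<?>[] {clazz}, this);
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // Object methods such as toString() are answered locally, not sent over the wire.
        if (method.getDeclaringClass() == Object.class) {
            return method.invoke(this, args);
        }
        // 1. assemble the request from the invoked method
        Call call = new Call(method.getDeclaringClass().getName(), method.getName(),
                args == null ? new Object[0] : args, method.getParameterTypes());
        // 2. send it and get a future that completes when the response arrives
        CompletableFuture<Object> future = sender.send(call);
        // 3. wait (with a timeout) for the result and return it to the caller
        try {
            return future.get(5, TimeUnit.SECONDS);
        } catch (ExecutionException e) {
            throw e.getCause(); // surface the original failure, not the wrapper
        } catch (TimeoutException e) {
            throw new IllegalStateException("rpc timeout: " + call.requestId, e);
        }
    }
}
```

Unwrapping the ExecutionException lets the caller see the original failure rather than a wrapper. Unchecked exceptions pass through the proxy unchanged.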
Sending requests is the client's core job, and there are several ways to do it. Only the implementation based on Netty NIO is shown here, step by step in the order a call proceeds.
First comes the sending adapter, which has to find a service address and then send the request to it. Its fields, constructor, and connection helpers are:
/**
 * Queue of requests waiting for a response
 */
private final WaitingProcessRequestQueue waitingProcessRequestQueue;
public RpcSendingServiceAdapterImpl() {
    this.findingAdapter = ExtensionLoader.getExtensionLoader(RpcServiceFindingAdapter.class)
            .getExtension(ServiceDiscoveryEnum.ZK.getName());
    this.addressChannelManager = SingletonFactory.getInstance(AddressChannelManager.class);
    this.waitingProcessRequestQueue = SingletonFactory.getInstance(WaitingProcessRequestQueue.class);
    // initialize the Netty client
    eventLoopGroup = new NioEventLoopGroup();
    bootstrap = new Bootstrap();
    bootstrap.group(eventLoopGroup)
            .channel(NioSocketChannel.class)
            // Connection timeout: if the connection cannot be established
            // within this time, the connection attempt fails.
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
            // A Bootstrap holds a single handler; a second .handler(...) call would
            // replace the first, so the LoggingHandler is added to the pipeline instead.
            .handler(new ChannelInitializer<SocketChannel>() {
                @Override
                protected void initChannel(SocketChannel ch) {
                    ChannelPipeline p = ch.pipeline();
                    p.addLast(new LoggingHandler(LogLevel.INFO));
                    // If no data is written to the server within 5 seconds,
                    // an idle event fires and a heartbeat request is sent
                    p.addLast(new IdleStateHandler(0, 5, 0, TimeUnit.SECONDS));
                    p.addLast(new RpcMessageEncoder());
                    p.addLast(new RpcMessageDecoder());
                    p.addLast(new NettyRpcClientHandler());
                }
            });
}
private Channel fetchAndConnectChannel(InetSocketAddress address) {
    Channel channel = addressChannelManager.get(address);
    if (channel == null) {
        // no cached channel for this address: connect and cache the new channel
        channel = connect(address);
        addressChannelManager.set(address, channel);
    }
    return channel;
}
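private Channel connect(InetSocketAddress address) {
    CompletableFuture<Channel> completableFuture = new CompletableFuture<>();
    bootstrap.connect(address).addListener((ChannelFutureListener) future -> {
        if (future.isSuccess()) {
            // hand the connected channel to the waiting caller
            LogUtil.info("The client has connected [{}] successfully!", address.toString());
            completableFuture.complete(future.channel());
        } else {
            LogUtil.error("The client failed to connect to the server [{}]", address.toString(), future.cause());
            // Complete exceptionally: an exception thrown inside the listener is only
            // logged by Netty and would leave completableFuture.get() blocked forever.
            completableFuture.completeExceptionally(future.cause());
        }
    });
    try {
        return completableFuture.get();
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new IllegalStateException("interrupted while connecting to " + address, e);
    } catch (ExecutionException e) {
        LogUtil.error("occur exception when connect to server:", e);
        throw new IllegalStateException("failed to connect to " + address, e.getCause());
    }
}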
The core method of this class is sendRpcRequest. It discovers a service address, obtains (or creates) a connection to that address, creates a Future for the pending response, and sends the request.
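Here is a sketch of the sendRpcRequest flow without Netty. It is built from hypothetical stand-ins:

- `AddressFinder` returns an address for a service name.
- `Connection` writes a request.
- `PendingRequests`, a `ConcurrentHashMap` keyed by request ID, plays the role WaitingProcessRequestQueue presumably has. When a response arrives, it completes the matching future.

```java
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

// Hypothetical stand-ins for the registry lookup and the network channel.
interface AddressFinder {
    String find(String serviceName);
}

interface Connection {
    void write(String requestId, String serviceName, Object payload);
}

class PendingRequests {
    private final Map<String, CompletableFuture<Object>> waiting = new ConcurrentHashMap<>();

    void put(String requestId, CompletableFuture<Object> future) {
        waiting.put(requestId, future);
    }

    // Called by the response handler: removes and completes the matching future.
    boolean complete(String requestId, Object result) {
        CompletableFuture<Object> future = waiting.remove(requestId);
        return future != null && future.complete(result);
    }

    boolean fail(String requestId, Throwable cause) {
        CompletableFuture<Object> future = waiting.remove(requestId);
        return future != null && future.completeExceptionally(cause);
    }

    int size() {
        return waiting.size();
    }
}

class SendingSketch {
    private final AddressFinder finder;
    private final Function<String, Connection> connector; // address -> (cached) connection
    private final PendingRequests pending;

    SendingSketch(AddressFinder finder, Function<String, Connection> connector, PendingRequests pending) {
        this.finder = finder;
        this.connector = connector;
        this.pending = pending;
    }

    CompletableFuture<Object> sendRpcRequest(String serviceName, Object payload) {
        String requestId = UUID.randomUUID().toString();
        // 1. discover a service address
        String address = finder.find(serviceName);
        // 2. get or create a connection to it
        Connection connection = connector.apply(address);
        // 3. register the future before writing, so an early response is not missed
        CompletableFuture<Object> future = new CompletableFuture<>();
        pending.put(requestId, future);
        // 4. send the request; on failure, fail the future instead of leaving it pending
        try {
            connection.write(requestId, serviceName, payload);
        } catch (RuntimeException e) {
            pending.fail(requestId, e);
        }
        return future;
    }
}
```

The future is registered before the write because the response is handled on another thread, such as a Netty I/O thread, and could arrive almost immediately.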
Discover services
Service discovery consists of two steps:
1. Pull the list of service addresses from the registry.
2. Pick one specific address from that list with a load-balancing algorithm.
Get the address
First, implement step 1. Caching could optimize this further. In this project, the ZooKeeper client uses a ConcurrentHashMap in place of a dedicated cache (see CuratorClient for the details).
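Here is a rough sketch of that caching idea. It assumes a hypothetical `registryLookup` function that queries the registry, for example by listing the children of a service's ZooKeeper node. Addresses are cached per service name in a `ConcurrentHashMap` and dropped when the registry reports a change.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

class ServiceAddressCache {
    private final Map<String, List<String>> cache = new ConcurrentHashMap<>();
    private final Function<String, List<String>> registryLookup; // hypothetical registry query

    ServiceAddressCache(Function<String, List<String>> registryLookup) {
        this.registryLookup = registryLookup;
    }

    // Returns the cached list; the registry is queried only on the first request for a service.
    List<String> getServiceAddresses(String rpcServiceName) {
        return cache.computeIfAbsent(rpcServiceName,
                name -> Collections.unmodifiableList(new ArrayList<>(registryLookup.apply(name))));
    }

    // Intended to be called from a registry watcher when the service's address list changes.
    void invalidate(String rpcServiceName) {
        cache.remove(rpcServiceName);
    }
}
```

Returning the same list instance from the cache also matters for load balancing. selectServiceAddress keys its selectors on `System.identityHashCode` of the list, so a fresh list on every call would mean a new selector every time.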
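Step 2 picks one address with a load-balancing algorithm, and this project uses consistent hashing. Consistent hashing is used for data sharding and load balancing in distributed systems. Nodes and keys are hashed onto the same ring, and each key is served by the first node found clockwise from its position. When a node is added or removed, only the keys in the affected segment of the ring move. Virtual nodes give each physical node several positions on the ring, which spreads the load more evenly. The technique is widely used in distributed caching and load balancing.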
Implementation:
Hash value calculation
Following the consistent hashing scheme, we first need hash values derived from each service. In the implementation below, the input is first run through SHA-256, which produces a 32-byte (256-bit) digest. The helper that computes it is named md5Hash, but it must return 32 bytes. The code reads four 8-byte chunks from each digest, and a 16-byte MD5 digest would make calculateHash throw.
Such a digest is too long to use directly as a ring position, so it is cut into shorter values. Each digest then yields several hash values, so every node occupies several positions in the hash space. This makes the distribution more uniform and reduces how much of the hash space is redistributed when a node is added or removed.
calculateHash takes the idx-th 8-byte chunk of the digest (bytes idx*8 through idx*8+7) and combines it, little-endian, into a Long hash value.
protected static Long calculateHash(byte[] digest, int idx) {
    if (digest.length < (idx + 1) * 8) {
        throw new IllegalArgumentException("Insufficient length of digest");
    }

    long hash = 0;
    // Read 8 bytes of the digest, starting at byte idx * 8.
    // Each byte is masked to an unsigned value and shifted left by i * 8 bits,
    // so the first byte ends up in the lowest 8 bits (little-endian).
    for (int i = 0; i < 8; i++) {
        hash |= (255L & (long) digest[i + idx * 8]) << (8 * i);
    }
    return hash;
}
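The next sketch shows what these values look like. It assumes md5Hash is a SHA-256 digest computed with `java.security.MessageDigest` and derives the four `Long` ring positions from one virtual-node name. The class name `Hashing` is hypothetical.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

class Hashing {
    // Assumed equivalent of md5Hash: a 32-byte SHA-256 digest.
    static byte[] sha256(String input) {
        try {
            return MessageDigest.getInstance("SHA-256").digest(input.getBytes(StandardCharsets.UTF_8));
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-256 not available", e); // required on every JVM
        }
    }

    // Same logic as calculateHash above: little-endian read of the idx-th 8-byte chunk.
    static long calculateHash(byte[] digest, int idx) {
        if (digest.length < (idx + 1) * 8) {
            throw new IllegalArgumentException("Insufficient length of digest");
        }
        long hash = 0;
        for (int i = 0; i < 8; i++) {
            hash |= (255L & digest[i + idx * 8]) << (8 * i);
        }
        return hash;
    }

    // One 32-byte digest gives 32 / 8 = 4 ring positions.
    static long[] ringPositions(String virtualNodeName) {
        byte[] digest = sha256(virtualNodeName);
        long[] positions = new long[digest.length / 8];
        for (int j = 0; j < positions.length; j++) {
            positions[j] = calculateHash(digest, j);
        }
        return positions;
    }
}
```

Take a digest whose bytes are 1, 2, ..., 32. Chunk 0 reads as `0x0807060504030201L` because byte 0 lands in the lowest bits. With a 16-byte MD5-sized array, `calculateHash(digest, 2)` throws.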
Implement a virtual node selector.
Following the consistent hashing scheme, the virtual node selector creates several virtual nodes for each service address and places each one at a position on the ring. To select, it hashes the incoming key and returns the first node at or after that position. If there is none, it wraps around to the start of the ring.
private static class ConsistentHashLoadBalanceSelector {
    // the ring: hash value -> service address (each address owns many virtual nodes)
    private final TreeMap<Long, String> virtualInvokers;

    private ConsistentHashLoadBalanceSelector(List<String> serviceUrlList, int virtualNodeNumber) {
        this.virtualInvokers = new TreeMap<>();
        // generate the virtual nodes of each service address;
        // one address maps to multiple virtual nodes, whose hash values
        // come from the 32-byte digest returned by md5Hash
        LogUtil.info("init add serviceUrlList:{}", serviceUrlList);
        for (String serviceNode : serviceUrlList) {
            addVirtualNode(serviceNode, virtualNodeNumber);
        }
    }

    private void addVirtualNode(String serviceNode, int virtualNodeNumber) {
        // each digest yields 4 hashes, so virtualNodeNumber / 4 digests
        // produce virtualNodeNumber virtual nodes (dividing by 8 would produce only half)
        for (int i = 0; i < virtualNodeNumber / 4; i++) {
            String virtualNodeName = serviceNode + "#" + i;
            byte[] md5Hash = md5Hash(virtualNodeName);
            // the digest must be 32 bytes: 4 chunks of 8 bytes, one per virtual node
            for (int j = 0; j < 4; j++) {
                Long hash = calculateHash(md5Hash, j);
                virtualInvokers.put(hash, serviceNode);
            }
        }
    }

    public String select(String rpcServiceKey) {
        byte[] digest = md5Hash(rpcServiceKey);
        // use the first 8 bytes of the digest as the key's position on the ring
        return selectForKey(calculateHash(digest, 0));
    }

    private String selectForKey(long hashCode) {
        // first virtual node at or after the key's position...
        Map.Entry<Long, String> entry = virtualInvokers.ceilingEntry(hashCode);
        // ...wrapping around to the start of the ring if there is none
        if (entry == null) {
            entry = virtualInvokers.firstEntry();
        }
        return entry.getValue();
    }
}
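Below is a self-contained version of the selector for experimenting outside the framework. It assumes SHA-256 as the digest in place of the md5Hash helper and treats the virtual node count as a total per address. The class name `HashRing` and its `remove` method are hypothetical. `remove` is there to show that dropping an address only remaps the keys that were on it.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical standalone ring mirroring ConsistentHashLoadBalanceSelector, with SHA-256 as the digest.
class HashRing {
    // ring position -> service address
    private final TreeMap<Long, String> virtualNodes = new TreeMap<>();

    HashRing(List<String> addresses, int virtualNodeNumber) {
        for (String address : addresses) {
            // each 32-byte digest yields 4 positions, so this adds virtualNodeNumber nodes per address
            for (int i = 0; i < virtualNodeNumber / 4; i++) {
                byte[] digest = sha256(address + "#" + i);
                for (int j = 0; j < 4; j++) {
                    virtualNodes.put(hash(digest, j), address);
                }
            }
        }
    }

    // First virtual node clockwise from the key's position, wrapping around at the end.
    String select(String key) {
        Map.Entry<Long, String> entry = virtualNodes.ceilingEntry(hash(sha256(key), 0));
        if (entry == null) {
            entry = virtualNodes.firstEntry();
        }
        return entry.getValue();
    }

    // Remove all virtual nodes of one address, e.g. when it goes offline.
    void remove(String address) {
        virtualNodes.values().removeIf(address::equals);
    }

    int size() {
        return virtualNodes.size();
    }

    private static byte[] sha256(String input) {
        try {
            return MessageDigest.getInstance("SHA-256").digest(input.getBytes(StandardCharsets.UTF_8));
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }

    // Little-endian read of the idx-th 8-byte chunk, as in calculateHash.
    private static long hash(byte[] digest, int idx) {
        long h = 0;
        for (int i = 0; i < 8; i++) {
            h |= (255L & digest[i + idx * 8]) << (8 * i);
        }
        return h;
    }
}
```

Removing an address only deletes its own entries from the TreeMap, and every other entry keeps its position. Keys that were not on the removed address therefore still find the same node. Only keys on the removed address move to their next node clockwise.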
Implement a complete load balancing method
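Selectors are cached under a key built from the interface name plus the identity hash code of the service address list. If a selector already exists for that key, an address is taken from it directly. Otherwise, a new selector is created. Because the key uses the list's identity rather than its contents, a new list instance (for example, after the registry pushes an update) normally gets a fresh selector.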
/**
 * Choose one address from the list of available service addresses
 *
 * @param serviceUrlList service address list
 * @param rpcRequest     the current request
 * @return the selected service address
 */
@Override
public String selectServiceAddress(List<String> serviceUrlList, RpcRequest rpcRequest) {
    int serviceListHash = System.identityHashCode(serviceUrlList);
    String interfaceName = rpcRequest.getServiceName();
    String selectorKey = interfaceName + serviceListHash;
    // ...
}
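Following the caching strategy described above, the rest of selectServiceAddress could look roughly like the sketch below. Everything in it is hypothetical:

- A `ConcurrentHashMap` of selectors.
- A virtual node count of 160.
- A minimal `Selector` interface with a factory, which keeps the example self-contained.
- The key hashed onto the ring, passed in as `routingKey` because the original does not show what it uses.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;

// Hypothetical minimal selector contract (ConsistentHashLoadBalanceSelector would fill this role).
interface Selector {
    String select(String key);
}

class CachingLoadBalance {
    private static final int VIRTUAL_NODE_NUMBER = 160; // assumed value
    private final Map<String, Selector> selectors = new ConcurrentHashMap<>();
    private final BiFunction<List<String>, Integer, Selector> selectorFactory;

    CachingLoadBalance(BiFunction<List<String>, Integer, Selector> selectorFactory) {
        this.selectorFactory = selectorFactory;
    }

    String selectServiceAddress(List<String> serviceUrlList, String interfaceName, String routingKey) {
        int serviceListHash = System.identityHashCode(serviceUrlList);
        String selectorKey = interfaceName + serviceListHash;
        // reuse the selector built for this exact list instance, or build it once
        Selector selector = selectors.computeIfAbsent(selectorKey,
                key -> selectorFactory.apply(serviceUrlList, VIRTUAL_NODE_NUMBER));
        return selector.select(routingKey);
    }

    int cachedSelectorCount() {
        return selectors.size();
    }
}
```

Keying on identity means each new list instance gets its own selector, but old selectors are never evicted. A longer-lived version might key on the list's contents or remove stale entries when the registry reports a change.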
/**
 * Called when an exception occurs while processing a client message
 */
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
    LogUtil.error("server exceptionCaught");
    cause.printStackTrace();
    ctx.close();
}