0%

Java Interview Guide

自己在准备欧洲的面试过程发现,尽管欧洲的面试相对更简单,但是仍然会询问一些基础的八股文,而国内的面经虽然更全面,并且往往更深层,但是因为是中文的,所以在自己准备的时候还是很不方便,并且2边的问题侧重点还是有所不同的,所以我尝试对自己过去学习到的以及面试中遇到的问题进行整理和归纳,希望能帮助到各位在准备进行英文面试的朋友,也欢迎大家在阅读过程中给出自己的问题或者也提供一些自己曾经遇到的问题,祝大家面试好运。


📚 Table of Contents


🌐 Online Reading

This repository is published as a GitHub Pages site:
📖 Open Online Version

  • Best for mobile/tablet reading

📅 Future Plans

  • Add Design Patterns section
  • Add Elasticsearch and related tools
  • Continuous content updates
  • Feedback and suggestions are welcome! 🙌

📄 License

Apache-2.0 License

基于springBoot,手写一个简单的RPC框架(三)

image-20230522163034434

继续上一章,在实现了服务端注册和调用之后,需要来实现客户端的功能,其中主要包括负载均衡,限流,请求发送和服务发现上。接下来将从一个RPC调用流程的顺序来实现接下来的功能

一次请求:

​ 实现客户端之前,首先需要想清楚一次请求需要发送些什么。

​ 首先,需要当前的服务名方法名,以及对应的参数和参数类型,否则服务端无法根据请求来进行对应的反射调用。

​ 其次,请求中应该要带上@RpcConsumer内的参数,让服务端能够找到正确的服务。

​ 最后,请求中应该带上本次请求的一个唯一值,以方便链路追踪。

​ 至此,一个请求需要的基本参数已经完成。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class RpcRequest implements Serializable {

private static final long serialVersionUID = 8509587559718339795L;
/**
* traceId
*/
private String traceId;
/**
* interface name
*/
private String serviceName;
/**
* method name
*/
private String methodName;
/**
* parameters
*/
private Object[] parameters;
/**
* parameter types
*/
private Class<?>[] paramTypes;
/**
* version
*/
private String version;
/**
* group
*/
private String project;

private String group;

/**
* generate service name,use to distinguish different service,and * can be
* split to get the service name
*/
public String fetchRpcServiceName() {
return this.getProject() + "*" + this.getGroup() + "*" + this.getServiceName() + "*" + this.getVersion();
}

}

服务代理

​ 第一步,在spring启动的过程中,扫描所有带有@RpcConsumer 为其生成代理,后续调用到该类方法的时候都会调用代理的方法,发起请求。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
@Component
public class RpcBeanPostProcessor implements BeanPostProcessor {

private final RpcServiceRegistryAdapter adapter;

private final RpcSendingServiceAdapter sendingServiceAdapter;

public RpcBeanPostProcessor() {
this.adapter = SingletonFactory.getInstance(RpcServiceRegistryAdapterImpl.class);;
this.sendingServiceAdapter = ExtensionLoader.getExtensionLoader(RpcSendingServiceAdapter.class)
.getExtension(RpcRequestSendingEnum.NETTY.getName());
}

/**
* register service
*
* @param bean
* @param beanName
* @return
* @throws BeansException
*/
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
LogUtil.info("start process register service: {}", bean);
// register service
if (bean.getClass().isAnnotationPresent(RpcProvider.class)) {
RpcProvider annotation = bean.getClass().getAnnotation(RpcProvider.class);
// build rpc service config
RpcServiceConfig serviceConfig = RpcServiceConfig.builder()
.service(bean)
.project(annotation.project())
.version(annotation.version())
.group(annotation.group())
.build();
LogUtil.info("register service: {}", serviceConfig);
adapter.registryService(serviceConfig);
}
return bean;
}

/**
* proxy and injection of consumers
*
* @param bean
* @param beanName
* @return
* @throws BeansException
*/
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
Class<?> toBeProcessedBean = bean.getClass();
Field[] declaredFields = toBeProcessedBean.getDeclaredFields();
for (Field declaredField : declaredFields) {
if (declaredField.isAnnotationPresent(RpcConsumer.class)) {
RpcConsumer annotation = declaredField.getAnnotation(RpcConsumer.class);
// build rpc service config
RpcServiceConfig serviceConfig = RpcServiceConfig.builder()
.project(annotation.project())
.version(annotation.version())
.group(annotation.group())
.build();
// create the proxy bean Factory and the proxy bean
RpcServiceProxy proxy = new RpcServiceProxy(sendingServiceAdapter, serviceConfig);
Object rpcProxy = proxy.getProxy(declaredField.getType());
declaredField.setAccessible(true);
try {
LogUtil.info("create service proxy: {}", bean);
declaredField.set(bean, rpcProxy);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
return bean;
}
}

接下来就是在代理类的invoke方法中实现对request的拼装和调用。同时获取Future中的响应值返回给调用者。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
public class RpcServiceProxy implements InvocationHandler {

private final RpcSendingServiceAdapter sendingServiceAdapter;

private final RpcServiceConfig config;

public RpcServiceProxy(RpcSendingServiceAdapter sendingServiceAdapter, RpcServiceConfig config) {
this.sendingServiceAdapter = sendingServiceAdapter;
this.config = config;
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) {
LogUtil.info("invoked method: [{}]", method.getName());
RpcRequest rpcRequest = buildRequest(method,args);

RpcResponse<Object> rpcResponse = null;
CompletableFuture<RpcResponse<Object>> completableFuture =
(CompletableFuture<RpcResponse<Object>>)sendingServiceAdapter.sendRpcRequest(rpcRequest);
try {
rpcResponse = completableFuture.get();
return rpcResponse.getData();
} catch (Exception e) {
LogUtil.error("occur exception:", e);
}
return null;
}

/**
* get the proxy object
*/
@SuppressWarnings("unchecked")
public <T> T getProxy(Class<T> clazz) {
return (T)Proxy.newProxyInstance(clazz.getClassLoader(), new Class<?>[] {clazz}, this);
}

private RpcRequest buildRequest(Method method,Object[] args){
RpcRequest rpcRequest = RpcRequest.builder()
.methodName(method.getName())
.parameters(args)
.serviceName(method.getDeclaringClass().getName())
.paramTypes(method.getParameterTypes())
.traceId(UUID.randomUUID().toString())
.project(config.getProject())
.version(config.getVersion())
.group(config.getGroup())
.build();
return rpcRequest;
}
}

发送请求:

​ 客户端的核心方法为发送请求,请求的发送有多种方法,这里仅基于netty的Nio进行了实现。以下是一个完整的时序。

WX20230530-145829@2x

​ 首先实现发送方法,里面应该包含寻找地址,发送请求的功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
public class RpcSendingServiceAdapterImpl implements RpcSendingServiceAdapter {

/**
* EventLoopGroup is a multithreaded event loop that handles I/O operation.
*/
private final EventLoopGroup eventLoopGroup;

/**
* Bootstrap helt setting and start netty client
*/
private final Bootstrap bootstrap;

/**
* Service discovery
*/
private final RpcServiceFindingAdapter findingAdapter;

/**
* Channel manager,mapping channel and address
*/
private final AddressChannelManager addressChannelManager;

/**
* Waiting process request queue
*/
private final WaitingProcessRequestQueue waitingProcessRequestQueue;

public RpcSendingServiceAdapterImpl() {
this.findingAdapter = ExtensionLoader.getExtensionLoader(RpcServiceFindingAdapter.class)
.getExtension(ServiceDiscoveryEnum.ZK.getName());
this.addressChannelManager = SingletonFactory.getInstance(AddressChannelManager.class);
this.waitingProcessRequestQueue = SingletonFactory.getInstance(WaitingProcessRequestQueue.class);
// initialize
eventLoopGroup = new NioEventLoopGroup();
bootstrap = new Bootstrap();
bootstrap.group(eventLoopGroup)
.channel(NioSocketChannel.class)
.handler(new LoggingHandler(LogLevel.INFO))
// The timeout period for the connection.
// If this time is exceeded or if the connection cannot be
// established, the connection fails.
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) {
ChannelPipeline p = ch.pipeline();
// If no data is sent to the server within 15 seconds, a
// heartbeat request is sent
p.addLast(new IdleStateHandler(0, 5, 0, TimeUnit.SECONDS));
p.addLast(new RpcMessageEncoder());
p.addLast(new RpcMessageDecoder());
p.addLast(new NettyRpcClientHandler());
}
});
}

@Override
public Object sendRpcRequest(RpcRequest rpcRequest) {
CompletableFuture<RpcResponse<Object>> result = new CompletableFuture<>();
InetSocketAddress address = findServiceAddress(rpcRequest);
Channel channel = fetchAndConnectChannel(address);
if (channel.isActive()) {
addToProcessQueue(rpcRequest.getTraceId(), result);
RpcData rpcData = prepareRpcData(rpcRequest);
sendRpcData(channel, rpcData, result);
} else {
log.error("Send request[{}] failed", rpcRequest);
throw new IllegalStateException();
}
return result;
}
private InetSocketAddress findServiceAddress(RpcRequest rpcRequest) {
return findingAdapter.findServiceAddress(rpcRequest);
}

private void addToProcessQueue(String traceId, CompletableFuture<RpcResponse<Object>> result) {
waitingProcessRequestQueue.put(traceId, result);
}

private RpcData prepareRpcData(RpcRequest rpcRequest) {
return RpcData.builder()
.data(rpcRequest)
.serializeMethodCodec(SerializationTypeEnum.HESSIAN.getCode())
.compressType(CompressTypeEnum.GZIP.getCode())
.messageType(RpcConstants.REQUEST_TYPE)
.build();
}
private void sendRpcData(Channel channel, RpcData rpcData, CompletableFuture<RpcResponse<Object>> result) {
channel.writeAndFlush(rpcData).addListener((ChannelFutureListener)future -> {
if (future.isSuccess()) {
LogUtil.info("client send message: [{}]", rpcData);
} else {
future.channel().close();
result.completeExceptionally(future.cause());
LogUtil.error("Send failed:", future.cause());
}
});
}

private Channel fetchAndConnectChannel(InetSocketAddress address) {
Channel channel = addressChannelManager.get(address);
if (channel == null) {
// connect to service to get new address and rebuild the channel
channel = connect(address);
addressChannelManager.set(address, channel);
}
return channel;
}

private Channel connect(InetSocketAddress address) {
CompletableFuture<Channel> completableFuture = new CompletableFuture<>();
bootstrap.connect(address).addListener((ChannelFutureListener)future -> {
if (future.isSuccess()) {
// set channel to future
LogUtil.info("The client has connected [{}] successful!", address.toString());
completableFuture.complete(future.channel());
} else {
LogUtil.error("The client failed to connect to the server [{}],future", address.toString(), future);
throw new IllegalStateException();
}
});
Channel channel = null;
try {
channel = completableFuture.get();
} catch (Exception e) {
LogUtil.error("occur exception when connect to server:", e);
}
return channel;
}

public Channel getChannel(InetSocketAddress inetSocketAddress) {
Channel channel = addressChannelManager.get(inetSocketAddress);
if (channel == null) {
channel = connect(inetSocketAddress);
addressChannelManager.set(inetSocketAddress, channel);
}
return channel;
}
}

​ 在这个类中,核心的方法是sendRpcRequest,他负责获取服务,创建链接,创建一个Future任务,并且发送请求。

发现服务

发现服务的流程可以包括:

1.从注册中心中拉取服务地址列表

2.通过负载均衡算法获取服务具体类型。

获取地址

下面先实现第一步(此处可以使用缓存进行进一步的优化,本项目中的zk使用了一个ConcurrentHashMap来代替缓存,详细代码可以见CuratorClient):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
public class RpcServiceFindingAdapterImpl implements RpcServiceFindingAdapter {

private final LoadBalanceService loadBalanceService;

public RpcServiceFindingAdapterImpl() {
this.loadBalanceService = ExtensionLoader.getExtensionLoader(LoadBalanceService.class).getExtension(LOAD_BALANCE);
}

@Override
public InetSocketAddress findServiceAddress(RpcRequest rpcRequest) {
String serviceName = rpcRequest.fetchRpcServiceName();
CuratorFramework zkClient = CuratorClient.getZkClient();
List<String> serviceAddresseList = CuratorClient.getChildrenNodes(zkClient, serviceName);
if (CollectionUtils.isEmpty(serviceAddresseList)) {
throw new RuntimeException("no service available, serviceName: " + serviceName);
}

String service = loadBalanceService.selectServiceAddress(serviceAddresseList, rpcRequest);
if (StringUtils.isBlank(service)) {
throw new RuntimeException("no service available, serviceName: " + serviceName);
}
String[] socketAddressArray = service.split(":");
String host = socketAddressArray[0];
int port = Integer.parseInt(socketAddressArray[1]);
return new InetSocketAddress(host, port);
}
}
负载均衡——一致性哈希算法
定义:

一致性哈希算法是一种用于分布式系统中数据分片和负载均衡的算法。它通过引入虚拟节点和哈希环的概念,实现了节点的动态扩缩容时最小化数据迁移的需求,提高了系统的稳定性和性能。它在分布式缓存、负载均衡等场景中被广泛应用。

实现:

哈希值计算

首先,根据一致性哈希算法我们需要有根据对应的服务生成哈希值。在以下实现中,首先将输入通过SHA-256算法产生一个32字节(256位)的哈希值

但是这样的哈希值过长,并不方便处理,所以我们需要将他进行缩短。同时,一个节点映射多个哈希可以提高一致性哈希算法的分布均匀性,因为每个节点都会在哈希空间中拥有多个哈希值,这可以帮助减少因节点增加或减少而导致的哈希空间重分布的影响。

calculateHash 会对已经得到的256为哈希值从起点j开始向后取8字节生成一个新的Long类型的哈希值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
protected static byte[] md5Hash(String input) {
MessageDigest messageDigest = null;
try {
messageDigest = MessageDigest.getInstance("SHA-256");
byte[] hashBytes = messageDigest.digest(input.getBytes(StandardCharsets.UTF_8));
messageDigest.update(hashBytes);
return messageDigest.digest();
} catch (NoSuchAlgorithmException e) {
LogUtil.error("No such algorithm exception: {}", e.getMessage());
throw new RuntimeException(e);
}

}

protected static Long calculateHash(byte[] digest, int idx) {
if (digest.length < (idx + 1) * 8) {
throw new IllegalArgumentException("Insufficient length of digest");
}

long hash = 0;
// 8 bytes digest,a byte is 8 bits like :1321 2432
// each loop choose a byte to calculate hash,and shift i*8 bits
for (int i = 0; i < 8; i++) {
hash |= (255L & (long)digest[i + idx * 8]) << (8 * i);
}
return hash;
}

实现一个虚拟节点选择器。

根据一致性哈希算法的定义,一个虚拟节点选择器需要将服务生成多个虚拟节点,并且将每个节点映射为多个哈希值,最后根据传入的哈希值获取最近的节点返回给调用者。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
private static class ConsistentHashLoadBalanceSelector {
// hash to virtual node list
private final TreeMap<Long, String> virtualInvokers;

private ConsistentHashLoadBalanceSelector(List<String> serviceUrlList, int virtualNodeNumber) {
this.virtualInvokers = new TreeMap<>();
// generate service address virtual node]
// one address may map to multiple virtual nodes
// use the md5 hash algorithm to generate the hash value of the
// virtual node
LogUtil.info("init add serviceUrlList:{}", serviceUrlList);
for (String serviceNode : serviceUrlList) {
addVirtualNode(serviceNode, virtualNodeNumber);
}

}

private void addVirtualNode(String serviceNode, int virtualNodeNumber) {
for (int i = 0; i < virtualNodeNumber / 8; i++) {
String virtualNodeName = serviceNode + "#" + i;
byte[] md5Hash = md5Hash(virtualNodeName);
// md5Hash have 32 bytes
// use 8 byte for each virtual node
for (int j = 0; j < 4; j++) {
Long hash = calculateHash(md5Hash, j);
virtualInvokers.put(hash, serviceNode);
}
}
}

public String select(String rpcServiceKey) {
byte[] digest = md5Hash(rpcServiceKey);
// use first 8 byte to get hash
return selectForKey(calculateHash(digest, 0));
}

public String selectForKey(long hashCode) {
Map.Entry<Long, String> entry = virtualInvokers.tailMap(hashCode, true).firstEntry();

if (entry == null) {
entry = virtualInvokers.firstEntry();
}

return entry.getValue();
}

}

实现完整负载均衡方法

将接口名和可用服务列表的哈希作为key,缓存对应的一致性哈希选择器,如果存在则直接从已有的哈希选择器中获得一个负载节点,如果不存在,则新建一个。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
public class ConsistentHashLoadBalanceService implements LoadBalanceService {

private final Map<String, ConsistentHashLoadBalanceSelector> serviceToSelectorMap = new ConcurrentHashMap<>();

private static class ConsistentHashLoadBalanceSelector {
// hash to virtual node list
private final TreeMap<Long, String> virtualInvokers;

private ConsistentHashLoadBalanceSelector(List<String> serviceUrlList, int virtualNodeNumber) {
this.virtualInvokers = new TreeMap<>();
// generate service address virtual node]
// one address may map to multiple virtual nodes
// use the md5 hash algorithm to generate the hash value of the
// virtual node
LogUtil.info("init add serviceUrlList:{}", serviceUrlList);
for (String serviceNode : serviceUrlList) {
addVirtualNode(serviceNode, virtualNodeNumber);
}

}

private void addVirtualNode(String serviceNode, int virtualNodeNumber) {
for (int i = 0; i < virtualNodeNumber / 8; i++) {
String virtualNodeName = serviceNode + "#" + i;
byte[] md5Hash = md5Hash(virtualNodeName);
// md5Hash have 32 bytes
// use 8 byte for each virtual node
for (int j = 0; j < 4; j++) {
Long hash = calculateHash(md5Hash, j);
virtualInvokers.put(hash, serviceNode);
}
}
}

public String select(String rpcServiceKey) {
byte[] digest = md5Hash(rpcServiceKey);
// use first 8 byte to get hash
return selectForKey(calculateHash(digest, 0));
}

public String selectForKey(long hashCode) {
Map.Entry<Long, String> entry = virtualInvokers.tailMap(hashCode, true).firstEntry();

if (entry == null) {
entry = virtualInvokers.firstEntry();
}

return entry.getValue();
}

}

protected static byte[] md5Hash(String input) {
MessageDigest messageDigest = null;
try {
messageDigest = MessageDigest.getInstance("SHA-256");
byte[] hashBytes = messageDigest.digest(input.getBytes(StandardCharsets.UTF_8));
messageDigest.update(hashBytes);
return messageDigest.digest();
} catch (NoSuchAlgorithmException e) {
LogUtil.error("No such algorithm exception: {}", e.getMessage());
throw new RuntimeException(e);
}

}

protected static Long calculateHash(byte[] digest, int idx) {
if (digest.length < (idx + 1) * 8) {
throw new IllegalArgumentException("Insufficient length of digest");
}

long hash = 0;
// 8 bytes digest,a byte is 8 bits like :1321 2432
// each loop choose a byte to calculate hash,and shift i*8 bits
for (int i = 0; i < 8; i++) {
hash |= (255L & (long)digest[i + idx * 8]) << (8 * i);
}
return hash;
}

/**
* Choose one from the list of existing service addresses list
*
* @param serviceUrlList Service address list
* @param rpcRequest
* @return
*/
@Override
public String selectServiceAddress(List<String> serviceUrlList, RpcRequest rpcRequest) {
int serviceListHash = System.identityHashCode(serviceUrlList);
String interfaceName = rpcRequest.getServiceName();
String selectorKey = interfaceName + serviceListHash;

ConsistentHashLoadBalanceSelector consistentHashLoadBalanceSelector = serviceToSelectorMap
.computeIfAbsent(selectorKey, key -> new ConsistentHashLoadBalanceSelector(serviceUrlList, VIRTUAL_NODES));

return consistentHashLoadBalanceSelector.select(interfaceName + Arrays.stream(rpcRequest.getParameters()));
}

}

发送请求

发送请求
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
  @Override
public Object sendRpcRequest(RpcRequest rpcRequest) {
CompletableFuture<RpcResponse<Object>> result = new CompletableFuture<>();
InetSocketAddress address = findServiceAddress(rpcRequest);
Channel channel = fetchAndConnectChannel(address);
if (channel.isActive()) {
addToProcessQueue(rpcRequest.getTraceId(), result);
RpcData rpcData = prepareRpcData(rpcRequest);
sendRpcData(channel, rpcData, result);
} else {
log.error("Send request[{}] failed", rpcRequest);
throw new IllegalStateException();
}
return result;
}

private void addToProcessQueue(String traceId, CompletableFuture<RpcResponse<Object>> result) {
waitingProcessRequestQueue.put(traceId, result);
}

private RpcData prepareRpcData(RpcRequest rpcRequest) {
return RpcData.builder()
.data(rpcRequest)
.serializeMethodCodec(SerializationTypeEnum.HESSIAN.getCode())
.compressType(CompressTypeEnum.GZIP.getCode())
.messageType(RpcConstants.REQUEST_TYPE)
.build();
}
private void sendRpcData(Channel channel, RpcData rpcData, CompletableFuture<RpcResponse<Object>> result) {
channel.writeAndFlush(rpcData).addListener((ChannelFutureListener)future -> {
if (future.isSuccess()) {
LogUtil.info("client send message: [{}]", rpcData);
} else {
future.channel().close();
result.completeExceptionally(future.cause());
LogUtil.error("Send failed:", future.cause());
}
});
}
使用channel链接服务器
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
private Channel fetchAndConnectChannel(InetSocketAddress address) {
Channel channel = addressChannelManager.get(address);
if (channel == null) {
// connect to service to get new address and rebuild the channel
channel = connect(address);
addressChannelManager.set(address, channel);
}
return channel;
}

private Channel connect(InetSocketAddress address) {
CompletableFuture<Channel> completableFuture = new CompletableFuture<>();
bootstrap.connect(address).addListener((ChannelFutureListener)future -> {
if (future.isSuccess()) {
// set channel to future
LogUtil.info("The client has connected [{}] successful!", address.toString());
completableFuture.complete(future.channel());
} else {
LogUtil.error("The client failed to connect to the server [{}],future", address.toString(), future);
throw new IllegalStateException();
}
});
Channel channel = null;
try {
channel = completableFuture.get();
} catch (Exception e) {
LogUtil.error("occur exception when connect to server:", e);
}
return channel;
}

Consumer进行返回值处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
public class NettyRpcClientHandler extends SimpleChannelInboundHandler<RpcData> {

private final RpcSendingServiceAdapterImpl adapter;

private final WaitingProcessRequestQueue waitingProcessRequestQueue;

public NettyRpcClientHandler() {
this.adapter = SingletonFactory.getInstance(RpcSendingServiceAdapterImpl.class);
this.waitingProcessRequestQueue = SingletonFactory.getInstance(WaitingProcessRequestQueue.class);
}

/**
* heart beat handle
*
* @param ctx
* @param evt
* @throws Exception
*/
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
// if the channel is free,close it
if (evt instanceof IdleStateEvent) {
IdleState state = ((IdleStateEvent)evt).state();
if (state == IdleState.WRITER_IDLE) {
LogUtil.info("write idle happen [{}]", ctx.channel().remoteAddress());
Channel channel = adapter.getChannel((InetSocketAddress)ctx.channel().remoteAddress());
RpcData rpcData = new RpcData();
rpcData.setSerializeMethodCodec(SerializationTypeEnum.HESSIAN.getCode());
rpcData.setCompressType(CompressTypeEnum.GZIP.getCode());
rpcData.setMessageType(RpcConstants.HEARTBEAT_REQUEST_TYPE);
rpcData.setData(RpcConstants.PING);
channel.writeAndFlush(rpcData).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}
} else {
super.userEventTriggered(ctx, evt);
}
}

/**
* Called when an exception occurs in processing a client message
*/
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
LogUtil.error("server exceptionCaught");
cause.printStackTrace();
ctx.close();
}

@Override
protected void channelRead0(ChannelHandlerContext ctx, RpcData rpcData) throws Exception {
LogUtil.info("Client receive message: [{}]", rpcData);
RpcData rpcMessage = new RpcData();
setupRpcMessage(rpcMessage);

if (rpcData.isHeartBeatResponse()) {
LogUtil.info("heart [{}]", rpcMessage.getData());
} else if (rpcData.isResponse()) {
RpcResponse<Object> rpcResponse = (RpcResponse<Object>)rpcData.getData();
waitingProcessRequestQueue.complete(rpcResponse);
}
}

private void setupRpcMessage(RpcData rpcMessage) {
rpcMessage.setSerializeMethodCodec(SerializationTypeEnum.HESSIAN.getCode());
rpcMessage.setCompressType(CompressTypeEnum.GZIP.getCode());
}

}

java spi 应用

Spi 定义

在Java中,SPI代表Service Provider Interface(服务提供者接口)。SPI是一种机制,允许应用程序通过在类路径中发现和加载可插拔的组件或服务提供者来扩展其功能。它提供了一种松耦合的方式,允许开发人员编写可以与多个实现进行交互的代码,而无需显式地引用特定的实现类。

Api与Spi区别

API是用于定义软件组件之间的交互规则和约定,而SPI是一种机制,用于实现可插拔的组件或服务的扩展性。

API用于暴露功能和功能,供其他开发人员使用和集成,而SPI用于动态加载和使用可替换的组件实现。

API是面向开发人员的,提供了编程接口和文档,以便正确使用和集成软件组件。SPI是面向开发人员和框架/应用程序的,用于扩展框架或应用程序的功能。

API是被调用方定义的,它规定了调用方式和参数。SPI是调用方定义的,它允许被调用方提供实现。

Spi机制

WX20230605-150555@2x

实现过程

WX20230605-151041@2x

在Rpc中的实现

​ 在简单的RPC框架实现中基于Dubbo实现了一个基于注解的SPI机制,这里将简单介绍下原理。

实现:

1

定义一个SPI注解,用于标注所有需要进行SPI注册的接口

1
2
3
4
5
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface SPI {
}

其次,由于为了防止并发冲突问题,这里使用包装类对实例接口进行包装,用于保证多线程过程中的数据安全。

1
2
3
4
5
6
7
8
9
10
11
12
public class Holder<T> {

private volatile T value;

public T get() {
return value;
}

public void set(T value) {
this.value = value;
}
}

holder可以作为一个锁对象保证安全。

参考spring spi ,extension的配置文件定义为以下格式key = ‘full path’,如zk=org.example.ray.infrastructure.adapter.impl.RpcServiceFindingAdapterImpl

这样可以方便在系统内通过key创建获取和调用

2

为了能够避免重复加载创建的问题,使用map作为缓存。同时为了提高获取的效率,针对不同服务的服务,会有不同的拓展实例,缓存不同服务下的实例。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public final class ExtensionLoader<T> {
//extention path
private static final String SERVICE_DIRECTORY = "META-INF/extension/";
//save extentionloader for load different class
private static final Map<Class<?>, ExtensionLoader<?>> EXTENSION_LOADERS = new ConcurrentHashMap<>();
//save all instance
private static final Map<Class<?>, Object> EXTENSION_INSTANCES = new ConcurrentHashMap<>();

private final Class<?> type;
//save different service hold instance
private final Map<String, Holder<Object>> cachedInstances = new ConcurrentHashMap<>();
//save different instance
private final Holder<Map<String, Class<?>>> cachedClasses = new Holder<>();

private ExtensionLoader(Class<?> type) {
this.type = type;
}
}

3

实现获取classloader和getExtension方法

WX20230605-170936@2x

基于springBoot,手写一个简单的RPC框架(二)

image-20230522163034434

继续上一章,实现了服务注册后需要实现服务调用。

服务执行

一个RPC的服务调用应该分为以下几步:

请求监听;

解码请求;

方法调用;

返回结果;

接下来将依次实现以上功能;

请求监听

需要定义一个RpcRequest请求类,由于后续处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class RpcRequest implements Serializable {

private static final long serialVersionUID = 8509587559718339795L;
/**
* traceId
*/
private String traceId;
/**
* interface name
*/
private String serviceName;
/**
* method name
*/
private String methodName;
/**
* parameters
*/
private Object[] parameters;
/**
* parameter types
*/
private Class<?>[] paramTypes;
/**
* version
*/
private String version;
/**
* group
*/
private String project;

private String group;

/**
* generate service name,use to distinguish different service,and * can be split to get the service name
*/
public String fetchRpcServiceName() {
return this.getProject() +"*"+this.getGroup()+"*"+ this.getServiceName() +"*"+ this.getVersion();
}

}

监听请求需要启动一个netty server,用于监听请求service的服务。

启动时首先需要关闭之前注册的服务等资源。

随后对netty需要的资源进行依次初始化。

以下是一段netty的启动代码,其中需要加入编码和解码器用于协议解析,探活。

同时,需要加入限流和解码后的请求处理hanlder

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
@Component
public class NettyServer {

public NettyServer() {

}

public void start() {
LogUtil.info("netty server init");

ServerShutdownHook.getInstance().registerShutdownHook();

EventLoopGroup listenerGroup = initListenerGroup();
EventLoopGroup workerGroup = initWorkerGroup();
DefaultEventExecutorGroup businessGroup = initBusinessGroup();

LogUtil.info("netty server start");

try {
ServerBootstrap serverBootstrap = configureServerBootstrap(listenerGroup, workerGroup, businessGroup);
bindAndListen(serverBootstrap);
} catch (Exception e) {
LogUtil.error("occur exception when start server:", e);
} finally {
shutdown(listenerGroup, workerGroup, businessGroup);
}

}

private EventLoopGroup initListenerGroup() {
return new NioEventLoopGroup(1);
}

private EventLoopGroup initWorkerGroup() {
return new NioEventLoopGroup();
}

private DefaultEventExecutorGroup initBusinessGroup() {
return new DefaultEventExecutorGroup(
Runtime.getRuntime().availableProcessors() * 2,
ThreadPoolFactoryUtil.createThreadFactory("netty-server-business-group", false)
);
}

private ServerBootstrap configureServerBootstrap(EventLoopGroup listenerGroup, EventLoopGroup workerGroup, DefaultEventExecutorGroup businessGroup) {
ServerBootstrap serverBootstrap = new ServerBootstrap();
serverBootstrap.group(listenerGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.childOption(ChannelOption.SO_KEEPALIVE, true)
.childOption(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.SO_BACKLOG, 128)
.handler(new LoggingHandler(LogLevel.INFO))
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel socketChannel) throws Exception {
ChannelPipeline pipeline = socketChannel.pipeline();
pipeline.addLast(new IdleStateHandler(30, 0, 0, TimeUnit.SECONDS));
pipeline.addLast(new RpcMessageEncoder());
pipeline.addLast(new RpcMessageDecoder());
pipeline.addLast(new DefaultTrafficBlockHandler());
pipeline.addLast(businessGroup, new NettyRpcServerHandler());
}
});

return serverBootstrap;
}

private void bindAndListen(ServerBootstrap serverBootstrap) throws UnknownHostException, InterruptedException {
LogUtil.info("netty server bind port:{} " , PropertiesFileUtil.readPortFromProperties());
String host = InetAddress.getLocalHost().getHostAddress();
ChannelFuture f = serverBootstrap.bind(host, PropertiesFileUtil.readPortFromProperties()).sync();
f.channel().closeFuture().sync();
}

private void shutdown(EventLoopGroup listenerGroup, EventLoopGroup workerGroup, DefaultEventExecutorGroup businessGroup) {
listenerGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
businessGroup.shutdownGracefully();
}

}

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class ServerShutdownHook {

private static final ServerShutdownHook INSTANCE = new ServerShutdownHook();

public static ServerShutdownHook getInstance() {
return INSTANCE;
}

/**
* register shut down hook
*/
public void registerShutdownHook() {
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
// 执行清理操作
clearAll();
}));
}

private void clearAll() {
try {
// 清理注册表
InetSocketAddress inetSocketAddress = new InetSocketAddress(InetAddress.getLocalHost().getHostAddress(), PropertiesFileUtil.readPortFromProperties());
CuratorClient.clearRegistry(CuratorClient.getZkClient(), inetSocketAddress);
} catch (Exception ignored) {

}
// 关闭线程池
ThreadPoolFactoryUtil.shutDownAllThreadPool();
}

}

结合ApplicationRunner,实现server的自动启动

1
2
3
4
5
6
7
8
9
10
11
12
13
@Component
public class NettyServerRunner implements ApplicationRunner {

@Autowired
private NettyServer nettyServer;

public NettyServerRunner() {}

@Override
public void run(ApplicationArguments args) throws Exception {
nettyServer.start();
}
}

序列化

本项目默认只是实现了hessen的序列化和gzip加解压,这部分有许多的教程,所以在这里介绍。具体的代码可以在源码的org.example.ray.infrastructure.serialize包和org.example.ray.infrastructure.compress包中找到

编码与协议

实现了服务后,我们需要依次为他补充编码和处理类。

在实现编码的服务之前,首先应该确定底层的编码协议。

协议

本项目参考一些已有的协议设计,选择了一种比较简单的协议设计方式,如下图所示:

image-20230522174913202

协议由一个16byte的header和body组成。

其中0-4是magic code,用于校验

4-5为自定义的协议版本

5-8是整个message的长度,用于解码

8-9定义了消息类型,包括请求,响应,心跳请求,心跳响应。

10为编码方式

11为压缩方式

12-16为一个整型,为请求的编号

Java pojo如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class RpcData {
/**
* rpc message type
*/
private byte messageType;
/**
* serialization type
*/
private byte serializeMethodCodec;
/**
* compress type
*/
private byte compressType;
/**
* request id
*/
private int requestId;
/**
* request data
*/
private Object data;

public boolean isHeatBeatRequest() {
return messageType == RpcConstants.HEARTBEAT_REQUEST_TYPE;
}

public boolean canSendRequest() {
return messageType != RpcConstants.HEARTBEAT_REQUEST_TYPE
&& messageType != RpcConstants.HEARTBEAT_RESPONSE_TYPE;
}

public boolean isHeartBeatResponse() {
return messageType == RpcConstants.HEARTBEAT_RESPONSE_TYPE;
}

public boolean isResponse() {
return messageType == RpcConstants.RESPONSE_TYPE;
}
}

在了解了协议之后,实现解码

解码

LengthFieldBasedFrameDecoder解码器可以参考以下文章

1
https://zhuanlan.zhihu.com/p/95621344"

在了解LengthFieldBasedFrameDecoder解码器的基础上,解码的过程其实并不复杂。主要是解码header,校验,和解码body3部分,具体实现可以参考代码和注释。

解码部分使用java spi,可以定制选择反序列化和解压方法,此部分可以参考github中的代码,或者可以只使用固定序列化和解压方法替代spi部分。

本项目默认只是实现了hessen的序列化和gzip加解压,这部分有许多的教程,所以在这里介绍。具体的代码可以在源码的org.example.ray.infrastructure.serialize包和org.example.ray.infrastructure.compress包中找到

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
public class RpcMessageDecoder extends LengthFieldBasedFrameDecoder {

public RpcMessageDecoder() {
// lengthFieldOffset: magic code is 4B, and version is 1B, and then full
// length. so value is 5
// lengthFieldLength: full length is 4B. so value is 4
// lengthAdjustment: full length include all data and read 9 bytes
// before, so the left length is (fullLength-9). so values is -9
// initialBytesToStrip: we will check magic code and version manually,
// so do not strip any bytes. so values is 0
this(8 * 1024 * 1024, 5, 4, -9, 0);
}

public RpcMessageDecoder(int maxFrameLength, int lengthFieldOffset, int lengthFieldLength, int lengthAdjustment,
int initialBytesToStrip) {
super(maxFrameLength, lengthFieldOffset, lengthFieldLength, lengthAdjustment, initialBytesToStrip);
}

@Override
protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
// get the bytebuf which contains the frame
Object decode = super.decode(ctx, in);
if (decode instanceof ByteBuf) {
ByteBuf byteBuf = (ByteBuf)decode;
// if data not empty, decode it
if (byteBuf.readableBytes() >= RpcConstants.HEAD_LENGTH) {
try {
return decode(byteBuf);
} catch (Exception e) {
LogUtil.error("Decode error:{} ,input:{}", e, byteBuf);
} finally {
byteBuf.release();
}
}
}
return decode;
}

/**
* read byte array from byteBuf
*
* @param byteBuf
* @return
*/
private Object decode(ByteBuf byteBuf) {
LogUtil.info("start decode");
checkMagicCode(byteBuf);
checkVersion(byteBuf);

int fullLength = byteBuf.readInt();
RpcData rpcMessage = decodeRpcMessage(byteBuf);

if (rpcMessage.isHeatBeatRequest()) {
return handleHeatBeatRequest(rpcMessage);
}

if (rpcMessage.isHeartBeatResponse()) {
return handleHeartBeatResponse(rpcMessage);
}

return handleNormalRequest(rpcMessage, byteBuf, fullLength);
}

private RpcData decodeRpcMessage(ByteBuf byteBuf) {
LogUtil.info("start decode RpcMessage data");
byte messageType = byteBuf.readByte();
byte codec = byteBuf.readByte();
byte compress = byteBuf.readByte();
int traceId = byteBuf.readInt();

return RpcData.builder()
.serializeMethodCodec(codec)
.traceId(traceId)
.compressType(compress)
.messageType(messageType)
.build();
}

private RpcData handleHeatBeatRequest(RpcData rpcMessage) {
rpcMessage.setData(RpcConstants.PING);
return rpcMessage;
}

private RpcData handleHeartBeatResponse(RpcData rpcMessage) {
rpcMessage.setData(RpcConstants.PONG);
return rpcMessage;
}

private Object handleNormalRequest(RpcData rpcMessage, ByteBuf byteBuf, int fullLength) {
int bodyLength = fullLength - RpcConstants.HEAD_LENGTH;
if (bodyLength <= 0) {
return rpcMessage;
}
return decodeBody(rpcMessage, byteBuf, bodyLength);
}

private RpcData decodeBody(RpcData rpcMessage, ByteBuf byteBuf, Integer bodyLength) {
LogUtil.info("start decode body");
byte[] bodyBytes = new byte[bodyLength];
byteBuf.readBytes(bodyBytes);
// decompose
String compressName = CompressTypeEnum.getName(rpcMessage.getCompressType());
CompressService extension =
ExtensionLoader.getExtensionLoader(CompressService.class).getExtension(compressName);
bodyBytes = extension.decompress(bodyBytes);
// deserialize
if (rpcMessage.getMessageType() == RpcConstants.REQUEST_TYPE) {
RpcRequest rpcRequest = ExtensionLoader.getExtensionLoader(SerializationService.class)
.getExtension(SerializationTypeEnum.getName(rpcMessage.getSerializeMethodCodec()))
.deserialize(bodyBytes, RpcRequest.class);
rpcMessage.setData(rpcRequest);
} else {
RpcResponse rpcResponse = ExtensionLoader.getExtensionLoader(SerializationService.class)
.getExtension(SerializationTypeEnum.getName(rpcMessage.getSerializeMethodCodec()))
.deserialize(bodyBytes, RpcResponse.class);
rpcMessage.setData(rpcResponse);
}
return rpcMessage;

}

private void checkVersion(ByteBuf byteBuf) {
byte version = byteBuf.readByte();
if (version != RpcConstants.VERSION) {
throw new IllegalArgumentException("version is not compatible: " + version);
}
}

private void checkMagicCode(ByteBuf byteBuf) {
int length = RpcConstants.MAGIC_NUMBER.length;
byte[] magicNumber = new byte[length];
byteBuf.readBytes(magicNumber);
for (int i = 0; i < length; i++) {
if (magicNumber[i] != RpcConstants.MAGIC_NUMBER[i]) {
throw new IllegalArgumentException("Unknown magic code: " + new String(magicNumber));
}
}
}
}
编码

编码的过程相对简单,就是根据协议,依次将对应位的数据写入即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
public class RpcMessageEncoder extends MessageToByteEncoder<RpcData> {

private static final AtomicInteger ATOMIC_INTEGER = new AtomicInteger(0);

@Override
protected void encode(ChannelHandlerContext channelHandlerContext, RpcData rpcData, ByteBuf byteBuf) {
try {
//encode head,marked full length index
int fullLengthIndex = encodeHead(rpcData,byteBuf);
// encode body
int fullLength = encodeBody(rpcData, byteBuf);
// back fill full length
encodeLength(fullLengthIndex,fullLength,byteBuf);
} catch (Exception e) {
LogUtil.error("Encode request error:{},data:{}", e, rpcData);
throw new RpcException(RpcErrorMessageEnum.REQUEST_ENCODE_FAIL.getCode(),
RpcErrorMessageEnum.REQUEST_ENCODE_FAIL.getMessage());
}

}
private int encodeHead(RpcData rpcData,ByteBuf byteBuf){
// write magic code and version 0-5
byteBuf.writeBytes(RpcConstants.MAGIC_NUMBER);
byteBuf.writeByte(RpcConstants.VERSION);
// marked full length index.
int fullLengthIndex = byteBuf.writerIndex();
// write placeholder for full length 9+
byteBuf.writerIndex(byteBuf.writerIndex() + 4);
// write message type
byteBuf.writeByte(rpcData.getMessageType());
// write codec
byteBuf.writeByte(rpcData.getSerializeMethodCodec());
// write compress
byteBuf.writeByte(rpcData.getCompressType());
// write requestId
byteBuf.writeInt(ATOMIC_INTEGER.getAndIncrement());
return fullLengthIndex;
}

private int encodeBody(RpcData rpcData,ByteBuf byteBuf){
byte[] bodyBytes = null;
int fullLength = RpcConstants.HEAD_LENGTH;
if (rpcData.canSendRequest()) {
LogUtil.info("serialize request start");
bodyBytes = ExtensionLoader.getExtensionLoader(SerializationService.class)
.getExtension(SerializationTypeEnum.getName(rpcData.getSerializeMethodCodec()))
.serialize(rpcData.getData());
LogUtil.info("serialize request end");

String compressName = CompressTypeEnum.getName(rpcData.getCompressType());
CompressService extension =
ExtensionLoader.getExtensionLoader(CompressService.class).getExtension(compressName);
bodyBytes = extension.compress(bodyBytes);
fullLength += bodyBytes.length;
}
if (bodyBytes != null) {
byteBuf.writeBytes(bodyBytes);
}
return fullLength;
}

private void encodeLength(int fullLengthIndex,int fullLength,ByteBuf byteBuf){
int writeIndex = byteBuf.writerIndex();
byteBuf.writerIndex(fullLengthIndex);
byteBuf.writeInt(fullLength);
byteBuf.writerIndex(writeIndex);
}
}

请求处理和调用

这里实用netty的SimpleChannelInboundHandler,可以避免资源释放的问题

由于前面已经实现了解码,所以只需要针对不同的请求类型进行不同的处理即可。

如果是心跳请求,则返回心跳响应

如果是服务请求,则通过动态代理调用服务,并写入结果返回给消费者。

定义一个响应类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@AllArgsConstructor
@NoArgsConstructor
@Data
@Builder
public class RpcResponse<T> implements Serializable {

private static final long serialVersionUID = 347966260947189201L;
/**
* request id
*/
private String requestId;
/**
* response code
*/
private Integer code;
/**
* response message
*/
private String message;
/**
* response body
*/
private T data;

/**
* success
* @param data
* @param requestId
* @return
* @param <T>
*/
public static <T> RpcResponse<T> success(T data, String requestId) {
RpcResponse<T> response = new RpcResponse<>();
response.setCode(RpcResponseCodeEnum.SUCCESS.getCode());
response.setMessage(RpcResponseCodeEnum.SUCCESS.getMessage());
response.setRequestId(requestId);
if (null != data) {
response.setData(data);
}
return response;
}

/**
* fail
* @return
* @param <T>
*/
public static <T> RpcResponse<T> fail() {
RpcResponse<T> response = new RpcResponse<>();
response.setCode(RpcResponseCodeEnum.FAIL.getCode());
response.setMessage(RpcResponseCodeEnum.FAIL.getMessage());
return response;
}

}

serverhandler的核心方法为channelRead0

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
public class NettyRpcServerHandler extends SimpleChannelInboundHandler<RpcData> {
/**
* Read the message transmitted by the server
*/

private final RpcRequestHandler rpcRequestHandler;

public NettyRpcServerHandler() {
this.rpcRequestHandler = SingletonFactory.getInstance(RpcRequestHandler.class);
}

/**
* heart beat handle
*
* @param ctx
* @param evt
* @throws Exception
*/
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
// if the channel is free,close it
if (evt instanceof IdleStateEvent) {
IdleState state = ((IdleStateEvent)evt).state();
if (state == IdleState.READER_IDLE) {
LogUtil.info("idle check happen, so close the connection");
ctx.close();
}
} else {
super.userEventTriggered(ctx, evt);
}
}

/**
* Called when an exception occurs in processing a client message
*/
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
LogUtil.error("server exceptionCaught");
cause.printStackTrace();
ctx.close();
}

@Override
protected void channelRead0(ChannelHandlerContext ctx, RpcData rpcData) throws Exception {
LogUtil.info("Server receive message: [{}]", rpcData);
RpcData rpcMessage = new RpcData();
setupRpcMessage(rpcMessage);

if (rpcData.isHeatBeatRequest()) {
handleHeartbeat(rpcMessage);
} else {
handleRpcRequest(ctx, rpcData, rpcMessage);
}
ctx.writeAndFlush(rpcMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}

private void setupRpcMessage(RpcData rpcMessage) {
rpcMessage.setSerializeMethodCodec(SerializationTypeEnum.HESSIAN.getCode());
rpcMessage.setCompressType(CompressTypeEnum.GZIP.getCode());
}

private void handleHeartbeat(RpcData rpcMessage) {
rpcMessage.setMessageType(RpcConstants.HEARTBEAT_RESPONSE_TYPE);
rpcMessage.setData(RpcConstants.PONG);
}

private void handleRpcRequest(ChannelHandlerContext ctx, RpcData rpcData, RpcData rpcMessage) throws Exception {
RpcRequest rpcRequest = (RpcRequest)rpcData.getData();

// invoke target method
Object result = rpcRequestHandler.handle(rpcRequest);
LogUtil.info("Server get result: {}", result);

rpcMessage.setMessageType(RpcConstants.RESPONSE_TYPE);
buildAndSetRpcResponse(ctx, rpcRequest, rpcMessage, result);
}

private void
buildAndSetRpcResponse(ChannelHandlerContext ctx, RpcRequest rpcRequest, RpcData rpcMessage, Object result) {
if (canBuildResponse(ctx)) {
// If the channel is active and writable, a successful RPC response is constructed
RpcResponse<Object> rpcResponse = RpcResponse.success(result, rpcRequest.getTraceId());
rpcMessage.setData(rpcResponse);
} else {
// Construct a failed RPC response if the channel is not writable
RpcResponse<Object> rpcResponse = RpcResponse.fail();
rpcMessage.setData(rpcResponse);
LogUtil.error("Not writable now, message dropped,message:{}", rpcRequest);
}
}

private boolean canBuildResponse(ChannelHandlerContext ctx) {
return ctx.channel().isActive() && ctx.channel().isWritable();
}
}

tipp: 注册到zk后缓存的服务,可以直接基于动态代理进行调用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
public class RpcRequestHandler {

private final RpcServiceRegistryAdapter adapter;

public RpcRequestHandler() {
this.adapter = SingletonFactory.getInstance(RpcServiceRegistryAdapterImpl.class);
}

/**
* Processing rpcRequest: call the corresponding method, and then return the
* method
*/
public Object handle(RpcRequest request) {
Object service = adapter.getService(request.fetchRpcServiceName());
return invoke(request, service);
}

/**
* get method execution results
*
* @param rpcRequest client request
* @param service service object
* @return the result of the target method execution
*/
private Object invoke(RpcRequest rpcRequest, Object service) {
Object result;
try {
Method method = service.getClass().getMethod(rpcRequest.getMethodName(), rpcRequest.getParamTypes());
result = method.invoke(service, rpcRequest.getParameters());
LogUtil.info("service:[{}] successful invoke method:[{}]", rpcRequest.getServiceName(),
rpcRequest.getMethodName());
} catch (NoSuchMethodException | IllegalArgumentException | InvocationTargetException
| IllegalAccessException e) {
LogUtil.error("occur exception when invoke target method,error:{},RpcRequest:{}", e, rpcRequest);
throw new RpcException(RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE.getCode(), RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE.getMessage());
}
return result;
}
}

至此,一个服务端的代码就完成了

基于springBoot,手写一个简单的RPC框架(一)


Code:pjpjsocute/rpc-service: personal rcp attempt (github.com)
技术栈使用包括:springboot,Zookeeper,netty,java spi


RPC定义

远程过程调用(Remote Procedure Call)是一种通信机制,允许不同的服务之间通过网络进行通信和交互。

通过RPC,一个服务可以向另一个服务发起请求并获取响应,就像本地调用一样,而无需开发者手动处理底层的网络通信细节。RPC框架会封装底层的网络传输,并提供了远程服务接口的定义、序列化和反序列化数据等功能。

rpc与http辨析:

​ HTTP是一种用于传输超文本的应用层协议,它在客户端和服务器之间进行通信。它基于请求-响应模型,客户端发送HTTP请求到服务器,服务器处理请求并返回相应的HTTP响应。RPC更类似一种架构思想,RPC可以用HTTP实现,TCP实现。

RPC流程

一个简单的RPC架构如图所示:

image-20230522161808172

如何实现:

一个简单的RPC调用链路:

image-20230522163034434

基于netty和ZK的服务端实现

根据上图,首先我们需要实现服务注册。

服务注册:

目前RPC框架大都支持注解的方式进行注册,这里也使用相同的方式。

定义注册注解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
@Inherited
public @interface RpcProvider {

/**
* Service group, default value is empty string
*/
String project() default "default";

/**
* Service version, default value is 1.0
*
* @return
*/
String version() default "1.0";

/**
* Service group, default value is empty string
*/
String group() default "default";

}

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.TYPE})
@Inherited
public @interface RpcConsumer {
/**
* Service project, default value is empty string
*/
String project() default "default";

/**
* Service version, default value is 1.0
*
* @return
*/
String version() default "1.0";

/**
* Service group, default value is empty string
*/
String group() default "default";
}

@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Import(CustomBeanScannerRegistrar.class)
@Documented
public @interface SimpleRpcApplication {

String[] basePackage();
}

​ 该注解将定义服务版本,group(区分同名同项目的不同接口),项目名,用于服务暴露。

​ 同理,还需要一个注解用于消费;一个注解定义需要扫描的包

在启动时注册服务

首先,需要将带有@provider的注解注册

获取需要扫描的包,随后带有注解的bean进行注册进入spring即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
public class CustomBeanScannerRegistrar implements ImportBeanDefinitionRegistrar, ResourceLoaderAware {

private ResourceLoader resourceLoader;

private static final String API_SCAN_PARAM = "basePackage";

private static final String SPRING_BEAN_BASE_PACKAGE = "org.example.ray";

@Override
public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
//get the scan annotation and the bean package to be scanned
String[] scanBasePackages = fetchScanBasePackage(importingClassMetadata);
LogUtil.info("scanning packages: [{}]", (Object) scanBasePackages);

// //scan the package and register the bean
// RpcBeanScanner rpcConsumerBeanScanner = new RpcBeanScanner(registry, RpcConsumer.class);
RpcBeanScanner rpcProviderBeanScanner = new RpcBeanScanner(registry, RpcProvider.class);
RpcBeanScanner springBeanScanner = new RpcBeanScanner(registry, Component.class);
if (resourceLoader != null) {
springBeanScanner.setResourceLoader(resourceLoader);
rpcProviderBeanScanner.setResourceLoader(resourceLoader);
}
int rpcServiceCount = rpcProviderBeanScanner.scan(scanBasePackages);
LogUtil.info("rpcServiceScanner扫描的数量 [{}]", rpcServiceCount);
LogUtil.info("scanning RpcConsumer annotated beans end");
}

@Override
public void setResourceLoader(ResourceLoader resourceLoader) {
this.resourceLoader = resourceLoader;
}

private String[] fetchScanBasePackage(AnnotationMetadata importingClassMetadata){
AnnotationAttributes annotationAttributes = AnnotationAttributes.fromMap(importingClassMetadata.getAnnotationAttributes(SimpleRpcApplication.class.getName()));
String[] scanBasePackages = new String[0];
if (annotationAttributes != null) {
scanBasePackages = annotationAttributes.getStringArray(API_SCAN_PARAM);
}
//user doesn't specify the package to scan,use the Application base package
if (scanBasePackages.length == 0) {
scanBasePackages = new String[]{((org.springframework.core.type.StandardAnnotationMetadata) importingClassMetadata).getIntrospectedClass().getPackage().getName()};
}
return scanBasePackages;
}

}
在bean初始化之前将服务和相关的配置进行注册,确保spring启动后服务已经注册成功
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
@Component
public class RpcBeanPostProcessor implements BeanPostProcessor {

private final RpcServiceRegistryAdapter adapter;

private final RpcSendingServiceAdapter sendingServiceAdapter;

public RpcBeanPostProcessor() {
this.adapter = SingletonFactory.getInstance(RpcServiceRegistryAdapterImpl.class);;
this.sendingServiceAdapter = ExtensionLoader.getExtensionLoader(RpcSendingServiceAdapter.class)
.getExtension(RpcRequestSendingEnum.NETTY.getName());
}

/**
* register service
*
* @param bean
* @param beanName
* @return
* @throws BeansException
*/
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
LogUtil.info("start process register service: {}", bean);
// register service
if (bean.getClass().isAnnotationPresent(RpcProvider.class)) {
RpcProvider annotation = bean.getClass().getAnnotation(RpcProvider.class);
// build rpc service config
RpcServiceConfig serviceConfig = RpcServiceConfig.builder()
.service(bean)
.project(annotation.project())
.version(annotation.version())
.group(annotation.group())
.build();
LogUtil.info("register service: {}", serviceConfig);
adapter.registryService(serviceConfig);
}
return bean;
}
}
实现服务注册的具体方法

注册一个服务,至少应该吧包括:服务提供者(ip),服务名,以及@RpcProvider 中的变量,所以,可以先定义一个RpcServiceConfig.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class RpcServiceConfig {
/**
* service version
*/
private String version = "";

/**
* target service
*/
private Object service;

/**
* belong to which project
*/
private String project = "";

/**
* group
*/
private String group = "";

/**
* generate service name,use to distinguish different service,and * can be split to get the service name
* @return
*/
public String fetchRpcServiceName() {
return this.getProject() + "*" + this.getGroup() + "*" + this.getServiceName() + "*" + this.getVersion();
}

/**
* get the interface name
*
* @return
*/
public String getServiceName() {
return this.service.getClass().getInterfaces()[0].getCanonicalName();
}

}

提供2个方法,注册服务与根据服务名得到对应的bean

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public interface RpcServiceRegistryAdapter {

/**
* @param rpcServiceConfig rpc service related attributes
*/
void registryService(RpcServiceConfig rpcServiceConfig);

/**
* @param rpcClassName rpc class name
* @return service object
*/
Object getService(String rpcClassName);

}

注册流程可以分为3步,生成地址->服务注册进入Zookeeper->注册进入缓存。这里使用一个ConcurrentHashMap来进行缓存服务(方法中最后调用了zookeeper的api进行注册,因为与RPC关联不大,所以略过,可以直接参考源码)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
public class RpcServiceRegistryAdapterImpl implements RpcServiceRegistryAdapter {

/**
* cache map
*/
private final Map<String, Object> serviceMap = new ConcurrentHashMap<>();

@Override
public void registryService(RpcServiceConfig rpcServiceConfig) {
try {
// first get address and service
String hostAddress = InetAddress.getLocalHost().getHostAddress();
// add service to zk
LogUtil.info("add service to zk,service name{},host:{}", rpcServiceConfig.fetchRpcServiceName(),hostAddress);
registerServiceToZk(rpcServiceConfig.fetchRpcServiceName(),
new InetSocketAddress(hostAddress, PropertiesFileUtil.readPortFromProperties()));
// add service to map cache
registerServiceToMap(rpcServiceConfig);
} catch (UnknownHostException e) {
LogUtil.error("occur exception when getHostAddress", e);
throw new RuntimeException(e);
}

}

@Override
public Object getService(String rpcServiceName) {
Object service = serviceMap.get(rpcServiceName);
if (null == service) {
throw new RpcException(RpcErrorMessageEnum.SERVICE_CAN_NOT_BE_FOUND.getCode(),"service not found");
}
return service;
}

private void registerServiceToZk(String rpcServiceName, InetSocketAddress inetSocketAddress) {
String servicePath = CuratorClient.ZK_REGISTER_ROOT_PATH + "/" + rpcServiceName + inetSocketAddress.toString();
CuratorFramework zkClient = CuratorClient.getZkClient();
CuratorClient.createPersistentNode(zkClient, servicePath);
}

private void registerServiceToMap(RpcServiceConfig rpcServiceConfig) {
String rpcServiceName = rpcServiceConfig.fetchRpcServiceName();
if (serviceMap.containsKey(rpcServiceName)) {
return;
}
serviceMap.put(rpcServiceName, rpcServiceConfig.getService());
}
}

​ 到这里,一个服务的注册流程已经完成。

Q:

给你一个字符串 s,找到 s 中最长的回文子串。

如果字符串的反序与原始字符串相同,则该字符串称为回文字符串。

Input:

1
2
3
输入:s = "babad"
输出:"bab"
解释:"aba" 同样是符合题意的答案。

Solution:

如果是第一次做类似的题目,其实一下子是并不容易想到可以使用动态规划。

但是,通过一步步的思考和优化,使用动态规划其实是一种比较自然的想法。

首先:
根据题目的要求,一种最简单的做法应该是:
  1. 设计一个方法 judgeIsPalindrome ,基于双指针,我们可以得到一个 O(n) 的判断 String(i,j) 是否是回文串的方法。

  2. 枚举所有的长度 k,从大到小判断所有的 String(i, i+k-1),判断是否是回文串,如果是,则得到题目要求。

这个做法很简单,但是存在一个问题就是,需要枚举所有的 k,最坏的情况需要枚举 N 种,而对每种长度,最坏需要计算 N 次才能得到结果。所以整体复杂度很高,无法通过。需要寻找一种更优的做法。

根据以上的思路,一种想法就是并不需要枚举每一个 k,而是改为使用二分法搜索最大的 k,这样复杂度会变为 log N。有兴趣的可以自己尝试一下这个思路。
上面这个思路是有可能通过的,但是仍然可以优化。我们可以注意到,上面的做法还有一个问题就是,每次都需要重新判断新长度的子串是否是回文,而没有利用到之前计算过的结果。

假设我们对一个特定的长度 k_i,我们已经知道了 String 中所有长度为 k_i 的子串是否是回文串。那么当我们知道 String(i, i+k_i-1) 是否是回文的时候,对于 String(i-1, i+k_i) 其实这个结果也是可以在 O(1) 的时间内得到:

1
dp[i][j] = dp[i+1][j-1] ^ String[i] == String[j]

即只有 String(i, i+k_i-1) 是回文并且 String[i] == String[j],那么 String(i-1, i+k_i) 才能是回文串。

下面可以写出 DP 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
public String longestPalindrome(String s) {
int len = s.length();
if (len < 2) {
return s;
}
boolean[][] dp = new boolean[len][len];
for (int i = 0; i < len; i++) {
dp[i][i] = true;
}
int maxLen = 1;
int start = 0;
for (int j = 1; j < len; j++) {
for (int i = j - 1; i >= 0; i--) {
if (s.charAt(i) == s.charAt(j)) {
if (j - i < 3) {
dp[i][j] = true;
} else {
dp[i][j] = dp[i + 1][j - 1];
}
} else {
dp[i][j] = false;
}

if (dp[i][j]) {
int curLen = j - i + 1;
if (curLen > maxLen) {
maxLen = curLen;
start = i;
}
}
}
}
return s.substring(start, start + maxLen);
}

树状数组与力扣中的应用

为什么会需要树状数组

思考以下问题

QA:
假设存在一个整数序列 input,例如 intput = [1,2,7,4,3],要求前 K 个数的和。

Solution:
一般我们会求一个前缀和数组 preSumArray,其中 preSumArray[i] 代表前 i 个数的和。
这样我们求前 N 个数的和只需要返回 preSumArray[N],时间复杂度为 O(1)。如果需要查询 K 次,则复杂度为 O(K)。

升级这个问题

QA:
假设存在一个整数序列 input,例如 intput = [1,2,7,4,3],现在在我们获取前 N 个数的和时,可能会先将 i 位置的数增加/减少 value

Solution:
一般我们会求一个前缀和数组 preSumArray,其中 preSumArray[i] 代表前 i 个数的和。
但是如果我们需要在第 i 位置插入一个数 x,在进行更新时需要更新 i 之后的所有 preSumArray
此时单次的更新时间为 O(N),K 次查询的复杂度为 O(KN)。
如果我们不使用 preSumArray,那么更新复杂度为 O(1),查询复杂度会变为 O(N)。

这时树状数组可以帮助我们快速解决这个问题


前置知识——二进制的应用

二进制有很多有趣的应用,这里介绍一个用法:

1
lowbit(x) = x & (-x)

这个式子的目的是 求出能整除 x 的最大 2 次幂,也就是 x 最右边的 1

例子:

  • 5 & -5 = 1
  • 10 & -10 = 2
  • 12 & -12 = 4

树状数组(Binary Indexed Tree, BIT)

定义

本质上它仍是一个数组,与 preSumArray 相似,存的依旧是和数组,但是它存放的是 i 位之前 (包括 i),lowbit(i) 个整数的和

WX20230515-152510@2x

1
2
3
4
5
6
7
8
B(1) = A(1);
B(2) = A(1)+A(2);
B(3) = A(3);
B(4) = A(1)+A(2)+A(3)+A(4);
B(5) = A(5);
B(6) = A(5)+A(6);
B(7) = A(7);
B(8) = A(1)+A(2)+A(3)+A(4)+A(5)+A(6)+A(7)+A(8);

tip: 树状数组的下标必须从 1 开始


使用

树状数组主要解决两个操作:求和更新

求和

例子:

1
2
getSum(7) = A(1)+...+A(7) = B(4)+B(6)+B(7)
getSum(6) = B(4)+B(6)

实现代码:

1
2
3
4
5
6
7
public int getSum(int x) {
int res = 0;
for(int i = x; i > 0; i -= lowbit(i)) {
res += bit[i];
}
return res;
}

递归形式:

1
2
3
4
5
6
public int getSum(int x) {
if(x <= 0) {
return 0;
}
return bit[x] + getSum(x - lowbit(x));
}

复杂度:O(logN)

如果要求 sum(i,j),只需要 getSum(j) - getSum(i-1)


更新

例子:update(6,7),即在位置 6 加上 7,需要更新 B(6)B(8)

实现代码:

1
2
3
4
5
public void update(int x, int value) {
for(int i = x; i < bit.length; i += lowbit(i)) {
bit[i] += value;
}
}

力扣中的应用

LeetCode-493

QA:
给定一个数组 nums ,如果 i < jnums[i] > 2*nums[j] 我们就将 (i, j) 称作一个 重要翻转对
返回给定数组中的重要翻转对的数量。

Input:

1
2
输入: [1,3,2,3,1]
输出: 2

Solution:

题目可以转换为求 在 j 元素左边比它 2 倍大的元素有几个,并求和。

  1. 将数组排序并离散化映射为 1-n 的有序序列。
  2. 统计每个数的出现次数。
  3. 求前缀和,得到映射后的个数。

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution {
class TrieArr {
long[] arr;
public TrieArr(int n) { arr = new long[n]; }
public int lowbit(int x) { return x & -x; }
public int getSum(int x) {
if(x <= 0) return 0;
return (int)(arr[x] + getSum(x - lowbit(x)));
}
public void update(int x, int c) {
for(int i = x; i < arr.length; i += lowbit(i)) arr[i] += c;
}
}

public int reversePairs(int[] nums) {
Map<Long,Integer> map = new HashMap<>();
TreeSet<Long> set = new TreeSet<>();
for(int i: nums) {
set.add((long)i);
set.add((long)i * 2);
}
int index = 1;
while(!set.isEmpty()) {
map.put(set.pollFirst(), index++);
}
TrieArr bit = new TrieArr(map.size()+1);
int ans = 0;
for(int i = 0; i < nums.length; i++) {
long target = (long)nums[i] * 2;
int l = map.get(target);
ans += bit.getSum(map.size()) - bit.getSum(l);
bit.update(map.get((long)nums[i]), 1);
}
return ans;
}
}

类似问题

  • LeetCode-307 等

Q:

给定数组 nums 和一个整数 k 。我们将给定的数组 nums 分成 最多 k 个相邻的非空子数组 。 分数 由每个子数组内的平均值的总和构成。

注意我们必须使用 nums 数组中的每一个数进行分组,并且分数不一定需要是整数。

返回我们所能得到的最大 分数 是多少。答案误差在 10-6 内被视为是正确的。

阅读全文 »

Q:

RandomizedCollection 是一种包含数字集合(可能是重复的)的数据结构。它应该支持插入和删除特定元素,以及删除随机元素。

实现 RandomizedCollection 类:

  • RandomizedCollection()初始化空的 RandomizedCollection 对象。
  • bool insert(int val) 将一个 val 项插入到集合中,即使该项已经存在。如果该项不存在,则返回 true ,否则返回 false
  • bool remove(int val) 如果存在,从集合中移除一个 val 项。如果该项存在,则返回 true ,否则返回 false 。注意,如果 val 在集合中出现多次,我们只删除其中一个。
  • int getRandom() 从当前的多个元素集合中返回一个随机元素。每个元素被返回的概率与集合中包含的相同值的数量 线性相关

您必须实现类的函数,使每个函数的 平均 时间复杂度为 O(1)

注意:生成测试用例时,只有在 RandomizedCollection至少有一项 时,才会调用 getRandom

阅读全文 »

Q:

给你一个整数数组 nums 和一个整数 k ,找出 nums 中和至少为 k最短非空子数组 ,并返回该子数组的长度。如果不存在这样的 子数组 ,返回 -1

子数组 是数组中 连续 的一部分。

阅读全文 »