From c94ff7eba96eeecd6f4868afc1fdd1049ba68310 Mon Sep 17 00:00:00 2001 From: wanghongenpin Date: Mon, 12 May 2025 19:20:07 +0800 Subject: [PATCH] http2 streamDependency (#388)(#51) --- lib/network/channel/channel.dart | 12 +++--- lib/network/channel/channel_context.dart | 15 +++++-- lib/network/channel/channel_dispatcher.dart | 10 +++-- lib/network/handle/http_proxy_handle.dart | 22 +++++------ lib/network/handle/relay_handle.dart | 2 +- lib/network/handle/websocket_handle.dart | 4 +- lib/network/http/codec.dart | 14 +++---- lib/network/http/h2/frame.dart | 5 ++- lib/network/http/h2/h2_codec.dart | 44 ++++++++++++++++----- lib/network/http/http_client.dart | 11 +++--- lib/network/util/proxy_helper.dart | 12 +++--- lib/network/util/system_proxy.dart | 2 +- test/web_test.dart | 11 +++--- 13 files changed, 104 insertions(+), 60 deletions(-) diff --git a/lib/network/channel/channel.dart b/lib/network/channel/channel.dart index a1777f9..35f5f89 100644 --- a/lib/network/channel/channel.dart +++ b/lib/network/channel/channel.dart @@ -109,8 +109,8 @@ class Channel { ///是否是ssl链接 bool get isSsl => _socket is SecureSocket; - Future write(Object obj) async { - var data = dispatcher.encoder.encode(obj); + Future write(ChannelContext channelContext, Object obj) async { + var data = dispatcher.encoder.encode(channelContext, obj); await writeBytes(data); } @@ -135,7 +135,7 @@ class Channel { if (!isClosed) { _socket.add(bytes); } - await _socket.flush(); + // await _socket.flush(); } catch (e, t) { if (e is StateError && e.message == "StreamSink is closed") { isOpen = false; @@ -148,8 +148,8 @@ class Channel { } ///写入并关闭此channel - Future writeAndClose(Object obj) async { - await write(obj); + Future writeAndClose(ChannelContext channelContext, Object obj) async { + await write(channelContext, obj); close(); } @@ -165,6 +165,8 @@ class Channel { await Future.delayed(const Duration(milliseconds: 150)); } isOpen = false; + await _socket.flush(); + await _socket.close(); _socket.destroy(); } diff --git a/lib/network/channel/channel_context.dart b/lib/network/channel/channel_context.dart index 1683966..a092ac4 100644 --- a/lib/network/channel/channel_context.dart +++ b/lib/network/channel/channel_context.dart @@ -1,7 +1,7 @@ - import 'package:proxypin/network/channel/channel.dart'; import 'package:proxypin/network/channel/host_port.dart'; import 'package:proxypin/network/http/codec.dart'; +import 'package:proxypin/network/http/h2/frame.dart'; import 'package:proxypin/network/http/h2/setting.dart'; import 'package:proxypin/network/http/http.dart'; import 'package:proxypin/network/util/attribute_keys.dart'; @@ -25,6 +25,7 @@ class ChannelContext { //http2 stream final Map> _streams = {}; + final Map _streamDependency = {}; ChannelContext(); @@ -82,7 +83,7 @@ class ChannelContext { void putStreamResponse(int streamId, HttpResponse response) { var pair = _streams[streamId]; if (pair == null) { - pair = Pair(null, response); + pair = Pair(null, response); _streams[streamId] = pair; } @@ -102,4 +103,12 @@ class ChannelContext { void removeStream(int streamId) { _streams.remove(streamId); } -} \ No newline at end of file + + void put(int streamId, HeadersFrame frame) { + _streamDependency[streamId] = frame; + } + + HeadersFrame? removeStreamDependency(int streamId) { + return _streamDependency.remove(streamId); + } +} diff --git a/lib/network/channel/channel_dispatcher.dart b/lib/network/channel/channel_dispatcher.dart index 018c824..3b5063f 100644 --- a/lib/network/channel/channel_dispatcher.dart +++ b/lib/network/channel/channel_dispatcher.dart @@ -52,7 +52,7 @@ class ChannelDispatcher extends ChannelHandler { if (clientChannel.isSsl && !remoteChannel.isSsl) { //代理认证 if (proxyInfo?.isAuthenticated == true) { - await HttpClients.connectRequest(remote, remoteChannel, proxyInfo: proxyInfo); + await HttpClients.connectRequest(channelContext, remote, remoteChannel, proxyInfo: proxyInfo); } await remoteChannel.secureSocket(channelContext, host: channelContext.getAttribute(AttributeKeys.domain)); @@ -152,6 +152,10 @@ class ChannelDispatcher extends ChannelHandler { handler.channelRead(channelContext, channel, data!); } catch (error, trace) { + logger.e( + "[${channelContext.clientChannel?.id}] channelRead error isSsl:${channel.isSsl} ${channelContext.clientChannel?.selectedProtocol} ${channelContext.serverChannel?.selectedProtocol} ${String.fromCharCodes(buffer.bytes)}", + error: error, + stackTrace: trace); buffer.clear(); exceptionCaught(channelContext, channel, error, trace: trace); } @@ -167,7 +171,7 @@ class ChannelDispatcher extends ChannelHandler { channelContext.currentRequest?.hostAndPort = channelContext.host; logger.d("webSocket ${data.request?.hostAndPort}"); - remoteChannel.write(data); + remoteChannel.write(channelContext, data); channelContext.listener?.onResponse(channelContext, data); @@ -208,7 +212,7 @@ class RawCodec extends Codec> { } @override - List encode(dynamic data) { + List encode(ChannelContext channelContext, dynamic data) { return data as List; } } diff --git a/lib/network/handle/http_proxy_handle.dart b/lib/network/handle/http_proxy_handle.dart index 64b0f68..dc3f5ed 100644 --- a/lib/network/handle/http_proxy_handle.dart +++ b/lib/network/handle/http_proxy_handle.dart @@ -27,12 +27,12 @@ class HttpProxyChannelHandler extends ChannelHandler { void channelRead(ChannelContext channelContext, Channel channel, HttpRequest msg) async { //下载证书 if (msg.uri == 'http://proxy.pin/ssl' || msg.requestUrl == 'http://127.0.0.1:${channel.socket.port}/ssl') { - ProxyHelper.crtDownload(channel, msg); + ProxyHelper.crtDownload(channelContext, channel, msg); return; } //请求本服务 if ((await localIps()).contains(msg.hostAndPort?.host) && msg.hostAndPort?.port == channel.socket.port) { - ProxyHelper.localRequest(msg, channel); + ProxyHelper.localRequest(channelContext, msg, channel); return; } @@ -77,7 +77,7 @@ class HttpProxyChannelHandler extends ChannelHandler { if (httpRequest.method == HttpMethod.connect) { channel.error = error; //记录异常 //https代理新建connect连接请求 返回ok 会继续发起正常请求 可以获取到请求内容 - await channel.write( + await channel.write(channelContext, HttpResponse(HttpStatus.ok.reason('Connection established'), protocolVersion: httpRequest.protocolVersion)); } else { rethrow; @@ -89,7 +89,7 @@ class HttpProxyChannelHandler extends ChannelHandler { if (httpRequest.method != HttpMethod.connect) { log.d("[${channel.id}] ${httpRequest.protocolVersion} ${httpRequest.method.name} ${httpRequest.requestUrl}"); if (HostFilter.filter(httpRequest.hostAndPort?.host)) { - await remoteChannel.write(httpRequest); + await remoteChannel.write(channelContext, httpRequest); return; } @@ -115,7 +115,7 @@ class HttpProxyChannelHandler extends ChannelHandler { await redirect(channelContext, channel, request, redirectUrl!); return; } - await remoteChannel.write(request); + await remoteChannel.write(channelContext, request); } } @@ -135,7 +135,7 @@ class HttpProxyChannelHandler extends ChannelHandler { httpRequest.headers.host = '${redirectUri.host}${redirectUri.hasPort ? ':${redirectUri.port}' : ''}'; var redirectChannel = await HttpClients.connect(redirectUri, proxyHandler, channelContext); channelContext.serverChannel = redirectChannel; - await redirectChannel.write(httpRequest); + await redirectChannel.write(channelContext, httpRequest); } /// 获取远程连接 @@ -167,10 +167,10 @@ class HttpProxyChannelHandler extends ChannelHandler { httpRequest.headers.set(HttpHeaders.PROXY_AUTHORIZATION, 'Basic $auth'); } - await proxyChannel.write(httpRequest); + await proxyChannel.write(channelContext, httpRequest); } else { if (clientChannel.isSsl) { - await HttpClients.connectRequest(hostAndPort, proxyChannel, proxyInfo: proxyInfo); + await HttpClients.connectRequest(channelContext, hostAndPort, proxyChannel, proxyInfo: proxyInfo); await proxyChannel.secureSocket(channelContext, host: hostAndPort.host, supportedProtocols: httpRequest.protocolVersion == "HTTP/2" ? ["h2"] : null); } @@ -201,7 +201,7 @@ class HttpProxyChannelHandler extends ChannelHandler { //https代理新建连接请求 if (httpRequest.method == HttpMethod.connect) { - await clientChannel.write( + await clientChannel.write(channelContext, HttpResponse(HttpStatus.ok.reason('Connection established'), protocolVersion: httpRequest.protocolVersion)); } return proxyChannel; @@ -232,7 +232,7 @@ class HttpResponseProxyHandler extends ChannelHandler { //域名是否过滤 if (HostFilter.filter(request?.hostAndPort?.host) || request?.method == HttpMethod.connect) { - await clientChannel.write(msg); + await clientChannel.write(channelContext, msg); return; } @@ -251,7 +251,7 @@ class HttpResponseProxyHandler extends ChannelHandler { listener?.onResponse(channelContext, response!); //发送给客户端 - await clientChannel.write(response!); + await clientChannel.write(channelContext, response!); } @override diff --git a/lib/network/handle/relay_handle.dart b/lib/network/handle/relay_handle.dart index 21fba72..b042a56 100644 --- a/lib/network/handle/relay_handle.dart +++ b/lib/network/handle/relay_handle.dart @@ -9,7 +9,7 @@ class RelayHandler extends ChannelHandler { @override void channelRead(ChannelContext channelContext, Channel channel, Object msg) async { //发送给客户端 - remoteChannel.write(msg); + remoteChannel.write(channelContext, msg); } @override diff --git a/lib/network/handle/websocket_handle.dart b/lib/network/handle/websocket_handle.dart index f23de7f..a887cc9 100644 --- a/lib/network/handle/websocket_handle.dart +++ b/lib/network/handle/websocket_handle.dart @@ -17,7 +17,7 @@ class WebSocketChannelHandler extends ChannelHandler { @override void channelRead(ChannelContext channelContext, Channel channel, Uint8List msg) { - proxyChannel.write(msg); + proxyChannel.writeBytes(msg); WebSocketFrame? frame; try { frame = decoder.decode(msg); @@ -32,6 +32,6 @@ class WebSocketChannelHandler extends ChannelHandler { message.messages.add(frame); channelContext.listener?.onMessage(channel, message, frame); logger.d( - "[${channelContext.clientChannel?.id}] socket channelRead ${frame.payloadLength} ${frame.fin} ${frame.payloadDataAsString}"); + "[${channelContext.clientChannel?.id}] websocket channelRead ${frame.payloadLength} ${frame.fin} ${frame.payloadDataAsString}"); } } diff --git a/lib/network/http/codec.dart b/lib/network/http/codec.dart index c8fd375..f0be141 100644 --- a/lib/network/http/codec.dart +++ b/lib/network/http/codec.dart @@ -66,7 +66,7 @@ abstract interface class Decoder { /// 编码 abstract interface class Encoder { - List encode(T data); + List encode(ChannelContext channelContext, T data); } /// 编解码器 @@ -154,9 +154,9 @@ abstract class HttpCodec implements Codec { void initialLine(BytesBuilder buffer, T message); @override - List encode(T message) { + List encode(ChannelContext channelContext, T message) { if (message.protocolVersion == "HTTP/2") { - return getH2Codec().encode(message); + return getH2Codec().encode(channelContext, message); } BytesBuilder builder = BytesBuilder(); @@ -278,8 +278,8 @@ class HttpServerCodec extends Codec { } @override - List encode(HttpResponse data) { - return responseCodec.encode(data); + List encode(ChannelContext channelContext, HttpResponse data) { + return responseCodec.encode(channelContext, data); } } @@ -293,7 +293,7 @@ class HttpClientCodec extends Codec { } @override - List encode(HttpRequest data) { - return requestCodec.encode(data); + List encode(ChannelContext channelContext, HttpRequest data) { + return requestCodec.encode(channelContext, data); } } diff --git a/lib/network/http/h2/frame.dart b/lib/network/http/h2/frame.dart index 8372fbf..3a17c54 100644 --- a/lib/network/http/h2/frame.dart +++ b/lib/network/http/h2/frame.dart @@ -19,6 +19,7 @@ enum FrameType { data, headers, priority, rstStream, settings, pushPromise, ping class FrameHeader { static const flagsEndStream = 0x01; static const flagsEndHeaders = 0x04; + static const flagsPriority = 0x20; final int length; final FrameType type; @@ -29,7 +30,7 @@ class FrameHeader { bool get hasPaddedFlag => (flags & 0x08) == 0x08; - bool get hasPriorityFlag => (flags & 0x20) == 0x20; + bool get hasPriorityFlag => (flags & flagsPriority) == flagsPriority; bool get hasEndHeadersFlag => (flags & flagsEndHeaders) == flagsEndHeaders; @@ -74,7 +75,7 @@ class HeadersFrame extends Frame { final bool exclusiveDependency; final int? streamDependency; final int? weight; - final List headerBlockFragment; + List headerBlockFragment; HeadersFrame(super.header, this.padLength, this.exclusiveDependency, this.streamDependency, this.weight, this.headerBlockFragment); diff --git a/lib/network/http/h2/h2_codec.dart b/lib/network/http/h2/h2_codec.dart index bccbe50..80dd7b2 100644 --- a/lib/network/http/h2/h2_codec.dart +++ b/lib/network/http/h2/h2_codec.dart @@ -108,8 +108,12 @@ abstract class Http2Codec implements Codec { switch (frameHeader.type) { case FrameType.headers: //处理HEADERS帧 - _handleHeadersFrame(channelContext, frameHeader, ByteBuf(framePayload)); + var headersFrame = _handleHeadersFrame(channelContext, frameHeader, ByteBuf(framePayload)); result.isDone = frameHeader.hasEndStreamFlag && frameHeader.hasEndHeadersFlag; + if (headersFrame.streamDependency != null) { + headersFrame.headerBlockFragment = []; + channelContext.put(frameHeader.streamIdentifier, headersFrame); + } break; case FrameType.continuation: //处理CONTINUATION帧 @@ -168,7 +172,7 @@ abstract class Http2Codec implements Codec { List
encodeHeaders(T message); @override - Uint8List encode(T data) { + Uint8List encode(ChannelContext channelContext, T data) { var bytesBuilder = BytesBuilder(); if (data.headers.getInt(HttpHeaders.CONTENT_LENGTH) != null) { data.headers.set(HttpHeaders.CONTENT_LENGTH.toLowerCase(), "${data.body?.length ?? 0}"); @@ -179,7 +183,7 @@ abstract class Http2Codec implements Codec { //headers var headers = encodeHeaders(data); - writeHeadersFrame(bytesBuilder, data.streamId!, headers, endStream: emptyBody); + writeHeadersFrame(bytesBuilder, channelContext, data.streamId!, headers, endStream: emptyBody); //body if (!emptyBody) { @@ -199,25 +203,26 @@ abstract class Http2Codec implements Codec { void writeHeadersFrame( BytesBuilder bytesBuilder, + ChannelContext channelContext, int streamId, List
headers, { StreamSetting? setting, bool endStream = true, }) { var fragment = _hpackEncoder.encode(headers); - var maxSize = setting?.maxFrameSize ?? maxFrameSize; + var maxSize = channelContext.setting?.maxFrameSize ?? maxFrameSize; if (fragment.length < maxSize) { int flags = FrameHeader.flagsEndHeaders; if (endStream) { flags |= FrameHeader.flagsEndStream; } - _writeFrame(bytesBuilder, FrameType.headers, flags, streamId, fragment); + _writeHeadersFrame(bytesBuilder, channelContext, flags, streamId, fragment); } else { var chunk = fragment.sublist(0, maxSize); fragment = fragment.sublist(maxSize); - _writeFrame(bytesBuilder, FrameType.headers, 0, streamId, chunk); + _writeHeadersFrame(bytesBuilder, channelContext, 0, streamId, chunk); while (fragment.length > maxSize) { var chunk = fragment.sublist(0, maxSize); @@ -234,8 +239,29 @@ abstract class Http2Codec implements Codec { } } - void _writeFrame(BytesBuilder bytesBuilder, FrameType type, int flag, int streamId, List payload) { - FrameHeader frameHeader = FrameHeader(payload.length, type, flag, streamId); + void _writeHeadersFrame( + BytesBuilder bytesBuilder, ChannelContext channelContext, int flags, int streamId, List payload) { + var streamPriority = channelContext.removeStreamDependency(streamId); + if (streamPriority != null) { + flags |= FrameHeader.flagsPriority; + bool exclusive = streamPriority.exclusiveDependency; + int streamDependency = streamPriority.streamDependency!; + + payload = [ + (exclusive ? 0x80 : 0) | (streamDependency & 0x7FFFFFFF) >> 24, + (streamDependency & 0x00FF0000) >> 16, + (streamDependency & 0x0000FF00) >> 8, + (streamDependency & 0x000000FF), + streamPriority.weight!, + ...payload + ]; + } + + _writeFrame(bytesBuilder, FrameType.headers, flags, streamId, payload); + } + + void _writeFrame(BytesBuilder bytesBuilder, FrameType type, int flags, int streamId, List payload) { + FrameHeader frameHeader = FrameHeader(payload.length, type, flags, streamId); // logger.d( // "${this is Http2RequestDecoder ? 'request' : 'response'} _writeFrame streamId: ${frameHeader.streamIdentifier} ${frameHeader.type} flags:${frameHeader.flags} endHeaders: ${frameHeader.hasEndHeadersFlag} endStream: ${frameHeader.hasEndStreamFlag} ${payload.length}"); @@ -300,7 +326,7 @@ abstract class Http2Codec implements Codec { weight = payload.readByte(); // 读取权重 logger.d( - "PRIORITY frame parsed: exclusive=$exclusiveDependency, streamDependency=$streamDependency, weight=$weight"); + "PRIORITY frame parsed: streamId:${frameHeader.streamIdentifier} padLength:$padLength exclusive=$exclusiveDependency, streamDependency=$streamDependency, weight=$weight"); } var headerBlockLength = payload.length - payload.readerIndex - padLength; diff --git a/lib/network/http/http_client.dart b/lib/network/http/http_client.dart index 3a8048a..bac1488 100644 --- a/lib/network/http/http_client.dart +++ b/lib/network/http/http_client.dart @@ -67,7 +67,7 @@ class HttpClients { var channel = await client.connect(connectHost, channelContext); if (proxyInfo != null) { - await connectRequest(hostAndPort, channel, proxyInfo: proxyInfo); + await connectRequest(channelContext, hostAndPort, channel, proxyInfo: proxyInfo); } if (hostAndPort.isSsl()) { @@ -88,7 +88,8 @@ class HttpClients { } ///发起代理连接请求 - static Future connectRequest(HostAndPort hostAndPort, Channel channel, {ProxyInfo? proxyInfo}) async { + static Future connectRequest(ChannelContext channelContext, HostAndPort hostAndPort, Channel channel, + {ProxyInfo? proxyInfo}) async { ChannelHandler handler = channel.dispatcher.handler; //代理 发送connect请求 var httpResponseHandler = HttpResponseHandler(); @@ -103,7 +104,7 @@ class HttpClients { proxyRequest.headers.set(HttpHeaders.PROXY_AUTHORIZATION, 'Basic $auth'); } - await channel.write(proxyRequest); + await channel.write(channelContext, proxyRequest); var response = await httpResponseHandler.getResponse(const Duration(seconds: 5)); channel.dispatcher.handler = handler; @@ -144,7 +145,7 @@ class HttpClients { ChannelContext channelContext = ChannelContext(); Channel channel = await client.connect(hostAndPort, channelContext); - await channel.write(request); + await channel.write(channelContext, request); return httpResponseHandler.getResponse(timeout).whenComplete(() => channel.close()); } @@ -175,7 +176,7 @@ class HttpClients { request.headers.remove(HttpHeaders.HOST); request.streamId = 1; } - await channel.write(request); + await channel.write(channelContext, request); return httpResponseHandler.getResponse(timeout).whenComplete(() => channel.close()); } } diff --git a/lib/network/util/proxy_helper.dart b/lib/network/util/proxy_helper.dart index 912e715..5766ff7 100644 --- a/lib/network/util/proxy_helper.dart +++ b/lib/network/util/proxy_helper.dart @@ -33,7 +33,7 @@ import '../components/host_filter.dart'; class ProxyHelper { //请求本服务 - static localRequest(HttpRequest msg, Channel channel) async { + static localRequest(ChannelContext channelContext, HttpRequest msg, Channel channel) async { //获取配置 if (msg.path == '/config') { final requestRewrites = await RequestRewriteManager.instance; @@ -50,7 +50,7 @@ class ProxyHelper { }), }; response.body = utf8.encode(json.encode(body)); - channel.writeAndClose(response); + channel.writeAndClose(channelContext, response); return; } @@ -58,11 +58,11 @@ class ProxyHelper { response.body = utf8.encode('pong'); response.headers.set("os", Platform.operatingSystem); response.headers.set("hostname", Platform.isAndroid ? Platform.operatingSystem : Platform.localHostname); - channel.writeAndClose(response); + channel.writeAndClose(channelContext, response); } /// 下载证书 - static void crtDownload(Channel channel, HttpRequest request) async { + static void crtDownload(ChannelContext channelContext, Channel channel, HttpRequest request) async { const String fileMimeType = 'application/x-x509-ca-cert'; var response = HttpResponse(HttpStatus.ok); response.headers.set(HttpHeaders.CONTENT_TYPE, fileMimeType); @@ -74,11 +74,11 @@ class ProxyHelper { response.headers.set("Content-Length", caBytes.lengthInBytes.toString()); if (request.method == HttpMethod.head) { - channel.writeAndClose(response); + channel.writeAndClose(channelContext, response); return; } response.body = caBytes; - channel.writeAndClose(response); + channel.writeAndClose(channelContext, response); } ///异常处理 diff --git a/lib/network/util/system_proxy.dart b/lib/network/util/system_proxy.dart index 2188b47..ea28b87 100644 --- a/lib/network/util/system_proxy.dart +++ b/lib/network/util/system_proxy.dart @@ -257,7 +257,7 @@ class WindowsSystemProxy extends SystemProxy { @override Future _setProxyPassDomains(String proxyPassDomains) async { var results = await _internetSettings('add', ['ProxyOverride', '/t', 'REG_SZ', '/d', proxyPassDomains, '/f']); - logger.e('set proxyPassDomains, stdout: $results'); + logger.i('set proxyPassDomains, stdout: $results'); } static Future _internetSettings(String cmd, List args) async { diff --git a/test/web_test.dart b/test/web_test.dart index a02378b..d7d6a23 100644 --- a/test/web_test.dart +++ b/test/web_test.dart @@ -1,6 +1,7 @@ import 'dart:async'; import 'dart:io'; +import 'package:proxypin/network/channel/channel_context.dart'; import 'package:proxypin/network/http/codec.dart'; import 'package:proxypin/network/http/http.dart'; @@ -29,13 +30,13 @@ socketTest() async { httpRequest.headers.set('user-agent', 'Dart/3.0 (dart:io)'); httpRequest.headers.set('accept-encoding', 'gzip'); httpRequest.headers.set(HttpHeaders.hostHeader, host); - + ChannelContext channelContext = ChannelContext(); var codec = HttpRequestCodec(); - print(String.fromCharCodes(codec.encode(httpRequest))); - socket.add(codec.encode(httpRequest)); + print(String.fromCharCodes(codec.encode(channelContext, httpRequest))); + socket.add(codec.encode(channelContext, httpRequest)); await socket.flush(); - // subscription.resume(); + // subscription.resume(); await completer.future; // await Future.delayed(const Duration(milliseconds: 1600)); @@ -55,7 +56,7 @@ socketTest() async { httpRequest = HttpRequest(HttpMethod.get, "/"); httpRequest.headers.set(HttpHeaders.hostHeader, host); - secureSocket.add(codec.encode(httpRequest)); + secureSocket.add(codec.encode(channelContext, httpRequest)); await secureSocket.flush(); await completer.future; }