/*
 * Copyright 2023 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import 'dart:math';
import 'dart:typed_data';

import 'package:proxypin/network/channel/channel_context.dart';
import 'package:proxypin/network/http/codec.dart';
import 'package:proxypin/network/http/h2/setting.dart';
import 'package:proxypin/network/http/http.dart';
import 'package:proxypin/network/util/byte_buf.dart';
import 'package:proxypin/network/util/logger.dart';

import 'frame.dart';
import 'hpack/hpack.dart';

/// HTTP/2 codec shared by the request and response directions.
///
/// Decodes frames from a [ByteBuf] into HTTP messages tracked per stream on
/// the [ChannelContext], and encodes messages back into HEADERS/CONTINUATION
/// and DATA frames. Frame types the codec does not interpret (and SETTINGS,
/// after it has been applied) are forwarded verbatim via `forward`.
///
/// NOTE(review): the instance holds HPACK coder state ([decoder], [encoder])
/// and per-stream bookkeeping, so it presumably must not be shared between
/// connections — confirm with the code that creates codecs.
abstract class Http2Codec<T extends HttpMessage> implements Codec<T, T> {
  /// Largest frame payload this codec emits. 16,384 is the initial
  /// SETTINGS_MAX_FRAME_SIZE that every peer must accept.
  static const maxFrameSize = 16384;

  /// The first bytes of the client connection preface ("PRI * HTTP/2.0").
  static final List<int> connectionPrefacePRI = "PRI * HTTP/2.0".codeUnits;

  /// Connection-specific header fields that must not be sent over HTTP/2.
  static const _connectionSpecificHeaders = {
    'connection',
    'keep-alive',
    'proxy-connection',
    'transfer-encoding',
    'upgrade',
  };

  HPackDecoder decoder = HPackDecoder();
  HPackEncoder encoder = HPackEncoder();

  /// Streams whose HEADERS frame carried END_STREAM without END_HEADERS.
  ///
  /// Such a stream is only complete once the CONTINUATION frame carrying
  /// END_HEADERS has been received.
  final Set<int> _pendingEndStream = {};

  /// Creates the message for a new stream from its decoded header fields
  /// (pseudo-header fields such as `:method` or `:status` included).
  T createMessage(ChannelContext channelContext, FrameHeader frameHeader, Map<String, List<String>> headers);

  /// Returns the message previously created for the frame's stream, if any.
  T? getMessage(ChannelContext channelContext, FrameHeader frameHeader);

  /// Decodes as many complete frames from [byteBuf] as are available.
  ///
  /// Returns as soon as a frame completes a message or must be forwarded.
  /// Otherwise returns a result with `isDone == false` once the buffer runs
  /// out or only a partial frame remains. A partial frame is left unread.
  @override
  DecoderResult<T> decode(ChannelContext channelContext, ByteBuf byteBuf, {bool resolveBody = true}) {
    // Client connection preface: forward it untouched.
    if (isConnectionPrefacePRI(byteBuf)) {
      DecoderResult<T> result = DecoderResult();
      result.forward = byteBuf.readAvailableBytes();
      return result;
    }
    // A truncated preface must not be parsed as a frame header; wait for more bytes.
    if (_isPartialConnectionPreface(byteBuf)) {
      return DecoderResult(isDone: false);
    }

    while (byteBuf.isReadable()) {
      FrameHeader? frameHeader = FrameReader._readFrameHeader(byteBuf);
      if (frameHeader == null) {
        return DecoderResult(isDone: false);
      }

      List<int>? framePayload = FrameReader._readFramePayload(byteBuf, frameHeader.length);
      if (framePayload == null) {
        // Rewind so the header is read again once the whole payload has arrived.
        byteBuf.readerIndex -= FrameReader.headerLength;
        return DecoderResult(isDone: false);
      }

      var result = parseHttp2Packet(channelContext, frameHeader, framePayload);
      if (result.isDone) {
        return result;
      }
    }
    return DecoderResult(isDone: false);
  }

  /// Applies a single frame to the stream state and reports whether it
  /// completed a message (`isDone`) or must be forwarded (`forward`).
  DecoderResult<T> parseHttp2Packet(ChannelContext channelContext, FrameHeader frameHeader, List<int> framePayload) {
    var result = DecoderResult<T>();
    final streamId = frameHeader.streamIdentifier;

    switch (frameHeader.type) {
      case FrameType.headers:
        _handleHeadersFrame(channelContext, frameHeader, ByteBuf(framePayload));
        result.isDone = frameHeader.hasEndStreamFlag && frameHeader.hasEndHeadersFlag;
        if (frameHeader.hasEndStreamFlag && !frameHeader.hasEndHeadersFlag) {
          // END_STREAM takes effect once the trailing CONTINUATION frame ends the field block.
          _pendingEndStream.add(streamId);
        }
        break;
      case FrameType.continuation:
        final message = getMessage(channelContext, frameHeader);
        if (message == null) {
          logger.e("CONTINUATION frame but no message found");
          result.forward = _rawFrame(frameHeader, framePayload);
          return result;
        }
        // NOTE(review): each fragment is HPACK-decoded on its own; a header field
        // split across a frame boundary would need the fragments concatenated first.
        final headers = _parseHeaders(channelContext, framePayload);
        for (final entry in headers.entries) {
          // Pseudo-header fields are consumed by createMessage, as in the HEADERS path.
          if (!entry.key.startsWith(":")) {
            message.headers.addValues(entry.key, entry.value);
          }
        }
        // Decide explicitly: the message is complete only when the field block
        // has ended AND the stream was ended by the preceding HEADERS frame.
        var done = false;
        if (frameHeader.hasEndHeadersFlag) {
          final endedByHeaders = _pendingEndStream.remove(streamId);
          final isHeadRequest = channelContext.getStreamRequest(streamId)?.method == HttpMethod.head;
          done = endedByHeaders || isHeadRequest;
        }
        result.isDone = done;
        break;
      case FrameType.data:
        _handleDataFrame(channelContext, frameHeader, ByteBuf(framePayload));
        result.isDone = frameHeader.hasEndStreamFlag;
        break;
      case FrameType.settings:
        SettingHandler.handleSettingsFrame(channelContext, frameHeader, ByteBuf(framePayload));
        result.forward = _rawFrame(frameHeader, framePayload);
        return result;
      default:
        // Other frame types are forwarded as-is.
        result.forward = _rawFrame(frameHeader, framePayload);
        return result;
    }

    if (result.isDone && streamId > 0) {
      result.data = getMessage(channelContext, frameHeader);
      result.data?.streamId = streamId;
      channelContext.currentRequest = channelContext.getStreamRequest(streamId);
      if (result.data is HttpResponse) {
        channelContext.removeStream(streamId);
      }
    }
    return result;
  }

  /// Builds the header list (pseudo-header fields first) to HPACK-encode for [message].
  List<Header> encodeHeaders(T message);

  /// Serializes [data] into HEADERS (+ CONTINUATION) frames followed by DATA frames.
  ///
  /// No frame payload exceeds [maxFrameSize]. END_STREAM is set on the HEADERS
  /// frame when the body is empty; otherwise it is set on the last DATA frame.
  @override
  Uint8List encode(T data) {
    final streamId = data.streamId!;
    final body = data.body;
    final out = BytesBuilder();

    final headerBlock = encoder.encode(encodeHeaders(data));
    _writeHeaderBlock(out, streamId, headerBlock, endStream: body == null || body.isEmpty);

    if (body != null && body.isNotEmpty) {
      for (var start = 0; start < body.length; start += maxFrameSize) {
        final end = min(start + maxFrameSize, body.length);
        final flags = end == body.length ? FrameHeader.flagsEndStream : 0;
        _writeFrame(out, FrameType.data, flags, streamId, body.sublist(start, end));
      }
    }
    return out.takeBytes();
  }

  /// Writes [headerBlock] as one HEADERS frame followed by as many
  /// CONTINUATION frames as needed to keep each payload within [maxFrameSize].
  ///
  /// END_STREAM (when [endStream] is true) goes on the HEADERS frame and
  /// END_HEADERS on the last frame of the block.
  void _writeHeaderBlock(BytesBuilder out, int streamId, List<int> headerBlock, {required bool endStream}) {
    var offset = 0;
    var first = true;
    // do/while so an empty header block still produces one HEADERS frame.
    do {
      final end = min(offset + maxFrameSize, headerBlock.length);
      var flags = 0;
      if (first && endStream) {
        flags |= FrameHeader.flagsEndStream;
      }
      if (end == headerBlock.length) {
        flags |= FrameHeader.flagsEndHeaders;
      }
      final type = first ? FrameType.headers : FrameType.continuation;
      _writeFrame(out, type, flags, streamId, headerBlock.sublist(offset, end));
      first = false;
      offset = end;
    } while (offset < headerBlock.length);
  }

  /// Appends one frame (9-byte header + [payload]) to [bytesBuilder].
  void _writeFrame(BytesBuilder bytesBuilder, FrameType type, int flag, int streamId, List<int> payload) {
    FrameHeader frameHeader = FrameHeader(payload.length, type, flag, streamId);
    bytesBuilder.add(frameHeader.encode());
    bytesBuilder.add(payload);
  }

  /// Re-serializes a frame header and payload so the frame can be forwarded unchanged.
  List<int> _rawFrame(FrameHeader frameHeader, List<int> framePayload) =>
      [...frameHeader.encode(), ...framePayload];

  /// Converts the message's regular (non-pseudo) header fields to HPACK [Header]s.
  ///
  /// Field names are lower-cased and connection-specific fields are dropped,
  /// as HTTP/2 requires. `te` is kept only with the value `trailers`.
  List<Header> _regularHeaders(T message) {
    final fields = <Header>[];
    message.headers.forEach((key, values) {
      final name = key.toLowerCase();
      if (_connectionSpecificHeaders.contains(name)) {
        return;
      }
      for (final value in values) {
        if (name == 'te' && value.trim().toLowerCase() != 'trailers') {
          continue;
        }
        fields.add(Header.ascii(name, value));
      }
    });
    return fields;
  }

  /// Whether the readable bytes of [data] start with the full [connectionPrefacePRI].
  bool isConnectionPrefacePRI(ByteBuf data) {
    if (data.readableBytes() < connectionPrefacePRI.length) {
      return false;
    }
    return _matchesPreface(data, connectionPrefacePRI.length);
  }

  /// Whether [data] holds a non-empty, strict prefix of the connection preface.
  bool _isPartialConnectionPreface(ByteBuf data) {
    final readable = data.readableBytes();
    if (readable == 0 || readable >= connectionPrefacePRI.length) {
      return false;
    }
    return _matchesPreface(data, readable);
  }

  /// Compares the first [count] readable bytes of [data] with the preface.
  bool _matchesPreface(ByteBuf data, int count) {
    for (int i = 0; i < count; i++) {
      if (data.get(data.readerIndex + i) != connectionPrefacePRI[i]) {
        return false;
      }
    }
    return true;
  }

  /// Appends the content of a DATA frame to the stream's message body.
  ///
  /// Layout: `[Pad Length (8)] Data Padding`. The padding trails the data, so
  /// the data length is the remaining payload minus the pad length.
  /// Throws if the declared padding is longer than the remaining payload.
  DataFrame _handleDataFrame(ChannelContext channelContext, FrameHeader frameHeader, ByteBuf payload) {
    int padLength = 0;
    if (frameHeader.hasPaddedFlag) {
      padLength = payload.readByte();
    }

    final dataLength = payload.readableBytes() - padLength;
    if (dataLength < 0) {
      throw Exception("DATA frame padding exceeds payload length");
    }
    var data = payload.readBytes(dataLength);
    payload.skipBytes(padLength);

    // NOTE(review): throws if no message exists for this stream (e.g. DATA
    // before HEADERS); consider forwarding such frames instead.
    var message = getMessage(channelContext, frameHeader)!;
    final existing = message.body;
    message.body = existing == null ? data : [...existing, ...data];
    return DataFrame(frameHeader, padLength, data);
  }

  /// Decodes a HEADERS frame, creates the stream's message and copies its
  /// regular (non-pseudo) header fields onto it.
  ///
  /// Layout: `[Pad Length (8)] [E + Stream Dependency (32) + Weight (8)]
  /// Field Block Fragment Padding`. The padding trails the field block.
  HeadersFrame _handleHeadersFrame(ChannelContext channelContext, FrameHeader frameHeader, ByteBuf payload) {
    int padLength = 0;
    if (frameHeader.hasPaddedFlag) {
      padLength = payload.readByte();
    }

    int? streamDependency;
    bool exclusiveDependency = false;
    int? weight;
    if (frameHeader.hasPriorityFlag) {
      int dependency = payload.readInt();
      // Top bit is the exclusive flag; the low 31 bits are the dependency stream id.
      exclusiveDependency = (dependency & 0x80000000) == 0x80000000;
      streamDependency = dependency & 0x7fffffff;
      weight = payload.readByte();
    }

    var headerBlockLength = payload.readableBytes() - padLength;
    if (headerBlockLength < 0) {
      throw Exception("headerBlockLength < 0");
    }
    var blockFragment = payload.readBytes(headerBlockLength);
    payload.skipBytes(padLength);

    Map<String, List<String>> headers = _parseHeaders(channelContext, blockFragment);
    T message = createMessage(channelContext, frameHeader, headers);
    headers.forEach((key, values) {
      if (!key.startsWith(":")) {
        message.headers.addValues(key, values);
      }
    });
    return HeadersFrame(frameHeader, padLength, exclusiveDependency, streamDependency, weight, blockFragment);
  }

  /// HPACK-decodes [payload] into a name → values map, preserving order and
  /// repeated names. Applies the peer's header table size first when the
  /// connection settings are known.
  Map<String, List<String>> _parseHeaders(ChannelContext channelContext, List<int> payload) {
    final setting = channelContext.setting;
    if (setting != null) {
      decoder.updateMaxReceivingHeaderTableSize(setting.headTableSize);
    }

    final headerMap = <String, List<String>>{};
    for (final header in decoder.decode(payload)) {
      headerMap.putIfAbsent(header.nameString, () => []).add(header.valueString);
    }
    return headerMap;
  }
}

/// HTTP/2 codec for the request direction.
class Http2RequestDecoder extends Http2Codec<HttpRequest> {
  /// Creates an [HttpRequest] from the `:method`/`:path` pseudo-header fields
  /// and registers it on the stream.
  @override
  HttpRequest createMessage(ChannelContext channelContext, FrameHeader frameHeader, Map<String, List<String>> headers) {
    HttpMethod httpMethod = HttpMethod.valueOf(headers[":method"]!.first);
    var httpRequest = HttpRequest(httpMethod, headers[":path"]!.first,
        protocolVersion: headers[":version"]?.firstOrNull ?? "HTTP/2");
    var old = channelContext.putStreamRequest(frameHeader.streamIdentifier, httpRequest);
    assert(old == null, "old request is not null");
    return httpRequest;
  }

  @override
  HttpRequest? getMessage(ChannelContext channelContext, FrameHeader frameHeader) {
    return channelContext.getStreamRequest(frameHeader.streamIdentifier);
  }

  /// Builds the request pseudo-header fields followed by the regular fields.
  @override
  List<Header> encodeHeaders(HttpRequest message) {
    var uri = message.requestUri!;
    return [
      Header.ascii(":method", message.method.name),
      Header.ascii(":scheme", uri.scheme),
      Header.ascii(":authority", _authority(uri)),
      Header.ascii(":path", message.uri),
      ..._regularHeaders(message),
    ];
  }

  /// Returns the `:authority` value for [uri].
  ///
  /// This is the host plus its explicit port, with IPv6 literals bracketed and
  /// userinfo omitted. [Uri.hasPort] is false for a scheme's default port.
  static String _authority(Uri uri) {
    final host = uri.host.contains(':') ? '[${uri.host}]' : uri.host;
    return uri.hasPort ? '$host:${uri.port}' : host;
  }
}

/// HTTP/2 codec for the response direction.
class Http2ResponseDecoder extends Http2Codec<HttpResponse> {
  /// Creates an [HttpResponse] from the `:status` pseudo-header field and
  /// links it to the stream's request.
  @override
  HttpResponse createMessage(
      ChannelContext channelContext, FrameHeader frameHeader, Map<String, List<String>> headers) {
    var httpResponse = HttpResponse(HttpStatus.valueOf(int.parse(headers[':status']!.first)),
        protocolVersion: headers[":version"]?.firstOrNull ?? 'HTTP/2');
    httpResponse.requestId = channelContext.getStreamRequest(frameHeader.streamIdentifier)!.requestId;
    channelContext.putStreamResponse(frameHeader.streamIdentifier, httpResponse);
    return httpResponse;
  }

  @override
  HttpResponse? getMessage(ChannelContext channelContext, FrameHeader frameHeader) {
    return channelContext.getStreamResponse(frameHeader.streamIdentifier);
  }

  /// Builds the `:status` pseudo-header field followed by the regular fields.
  @override
  List<Header> encodeHeaders(HttpResponse message) => [
        Header.ascii(":status", message.status.code.toString()),
        ..._regularHeaders(message),
      ];
}

/// Low-level readers for the 9-byte frame header and the frame payload.
class FrameReader {
  static int headerLength = 9;

  /// Reads [length] payload bytes, or returns null (consuming nothing) when
  /// fewer bytes are available. On success, already-read bytes are released
  /// via [ByteBuf.clearRead].
  static List<int>? _readFramePayload(ByteBuf data, int length) {
    if (data.readableBytes() < length) {
      return null;
    }

    var readBytes = data.readBytes(length);
    data.clearRead();
    return readBytes;
  }

  /// Reads a frame header (24-bit length, type, flags, 31-bit stream id), or
  /// returns null (consuming nothing) when fewer than [headerLength] bytes
  /// are available.
  static FrameHeader? _readFrameHeader(ByteBuf data) {
    if (data.readableBytes() < headerLength) {
      return null;
    }

    int length = data.read() << 16 | data.read() << 8 | data.read();
    // NOTE(review): assumes FrameType.values is indexed by the wire type code;
    // unknown extension frame types would throw a RangeError here, although
    // HTTP/2 says they should be ignored.
    FrameType type = FrameType.values[data.read()];
    int flags = data.read();
    // The top bit is reserved and must be ignored on receipt.
    int streamIdentifier = data.readInt() & 0x7fffffff;
    return FrameHeader(length, type, flags, streamIdentifier);
  }
}