支持head请求,修复手机端请求重写切换应用恢复原始的请求问题

This commit is contained in:
wanghongenpin
2023-12-17 15:36:28 +08:00
parent 9a558c5319
commit 9aba03cfcb
11 changed files with 119 additions and 56 deletions

View File

@@ -84,6 +84,8 @@ class Channel {
pipeline.listen(this);
}
String? get selectedProtocol => isSsl ? (_socket as SecureSocket).selectedProtocol : null;
///是否是ssl链接
bool get isSsl => _socket is SecureSocket;
@@ -228,10 +230,12 @@ class ChannelPipeline extends ChannelHandler<Uint8List> {
return;
}
var data = _decoder.decode(buffer);
HttpRequest? request = remoteChannel?.getAttribute(AttributeKeys.request);
var data = _decoder.decode(buffer, resolveBody: request?.method != HttpMethod.head);
if (data == null) {
return;
}
// print(String.fromCharCodes(buffer.buffer));
var length = buffer.length;
buffer.clear();
@@ -247,14 +251,14 @@ class ChannelPipeline extends ChannelHandler<Uint8List> {
if (data is HttpResponse) {
data.packageSize = length;
data.remoteAddress = '${channel.remoteAddress.host}:${channel.remotePort}';
data.request = remoteChannel?.getAttribute(AttributeKeys.request);
data.request?.response = data;
data.request = request;
request?.response = data;
}
//websocket协议
if (data is HttpResponse && data.isWebSocket && remoteChannel != null) {
data.request?.hostAndPort?.scheme = channel.isSsl ? HostAndPort.wssScheme : HostAndPort.wsScheme;
print("webSocket ${data.request?.hostAndPort}");
request?.hostAndPort?.scheme = channel.isSsl ? HostAndPort.wssScheme : HostAndPort.wsScheme;
logger.d("webSocket ${data.request?.hostAndPort}");
remoteChannel.write(data);
var rawCodec = RawCodec();
@@ -284,7 +288,7 @@ class ChannelPipeline extends ChannelHandler<Uint8List> {
class RawCodec extends Codec<Object> {
@override
Object? decode(ByteBuf data) {
Object? decode(ByteBuf data, {bool resolveBody = true}) {
return data.readBytes(data.readableBytes());
}

View File

@@ -113,7 +113,7 @@ class ByteBuf {
/// 解码
abstract interface class Decoder<T> {
/// 解码 如果返回null说明数据不完整
T? decode(ByteBuf byteBuf);
T? decode(ByteBuf byteBuf, {bool resolveBody = true});
}
/// 编码
@@ -139,7 +139,7 @@ abstract class HttpCodec<T extends HttpMessage> implements Codec<T> {
T createMessage(List<String> reqLine);
@override
T? decode(ByteBuf data) {
T? decode(ByteBuf data, {bool resolveBody = true}) {
//请求行
if (_state == State.readInitial) {
init();
@@ -156,10 +156,10 @@ abstract class HttpCodec<T extends HttpMessage> implements Codec<T> {
//请求体
if (_state == State.body) {
var result = bodyReader!.readBody(data.readBytes(data.readableBytes()));
if (result.isDone) {
var result = resolveBody ? bodyReader!.readBody(data.readBytes(data.readableBytes())) : null;
if (!resolveBody || result?.isDone == true) {
_state = State.done;
message.body = result.body;
message.body = result?.body;
}
}

View File

@@ -18,6 +18,7 @@ import 'dart:convert';
import 'package:network_proxy/network/host_port.dart';
import 'package:network_proxy/network/http/websocket.dart';
import 'package:network_proxy/network/util/logger.dart';
import 'package:network_proxy/utils/compress.dart';
import 'http_headers.dart';
@@ -272,7 +273,7 @@ enum HttpMethod {
try {
return HttpMethod.values.firstWhere((element) => element.name == name.toUpperCase());
} catch (error) {
print("$name :$error");
logger.e("HttpMethod error $name :$error");
rethrow;
}
}

View File

@@ -68,23 +68,9 @@ class WebSocketFrame {
///websocket 解码器
class WebSocketDecoder {
// Add a buffer to store incomplete data
final buffer = BytesBuilder();
WebSocketFrame? decode(Uint8List byteBuf) {
// Add the new data to the buffer
buffer.add(byteBuf);
// Try to parse a WebSocket frame from the buffer
var data = buffer.toBytes();
if (canParseWebSocketFrame(data)) {
var frame = _parseWebSocketFrame(data);
buffer.clear();
return frame;
}
return null;
var frame = _parseWebSocketFrame(byteBuf);
return frame;
}
bool canParseWebSocketFrame(Uint8List data) {

View File

@@ -98,10 +98,6 @@ class Network {
channel.putAttribute(AttributeKeys.domain, hostAndPort?.host);
Channel? remoteChannel = channel.getAttribute(channel.id);
if (remoteChannel != null) {
remoteChannel.secureSocket = await SecureSocket.secure(remoteChannel.socket,
host: hostAndPort?.host, onBadCertificate: (certificate) => true);
}
if (HostFilter.filter(hostAndPort?.host)) {
remoteChannel = remoteChannel ?? await HttpClients.startConnect(hostAndPort!, RelayHandler(channel));
@@ -110,6 +106,13 @@ class Network {
return;
}
if (remoteChannel != null && !remoteChannel.isSsl) {
// var supportProtocols = TLS.supportProtocols(data);
remoteChannel.secureSocket = await SecureSocket.secure(remoteChannel.socket,
host: hostAndPort?.host, onBadCertificate: (certificate) => true);
}
// var selectedProtocol = remoteChannel?.selectedProtocol;
//ssl自签证书
var certificate = await CertificateManager.getCertificateContext(hostAndPort!.host);
//服务端等待客户端ssl握手

View File

@@ -17,6 +17,62 @@
import 'dart:typed_data';
class TLS {
///从TLS Client Hello 获取支持的协议
static List<String>? supportProtocols(Uint8List data) {
try {
int sessionLength = data[43];
int pos = 44 + sessionLength;
if (data.length < pos + 2) return null;
int cipherSuitesLength = data.buffer.asByteData().getUint16(pos);
pos += 2 + cipherSuitesLength;
if (data.length < pos + 1) return null;
int compressionMethodsLength = data[pos];
pos += 1 + compressionMethodsLength;
if (data.length < pos + 2) return null;
int extensionsLength = data.buffer.asByteData().getUint16(pos);
pos += 2;
if (data.length < pos + extensionsLength) return null;
List<String> protocols = [];
int end = pos + extensionsLength;
while (pos + 4 <= end) {
int extensionType = data.buffer.asByteData().getUint16(pos);
int extensionLength = data.buffer.asByteData().getUint16(pos + 2);
pos += 4;
if (extensionType == 16 /* ALPN */) {
if (pos + 2 > end) return protocols;
int alpnExtensionLength = data.buffer.asByteData().getUint16(pos);
pos += 2;
if (pos + alpnExtensionLength > end) return protocols;
int alpnEnd = pos + alpnExtensionLength;
while (pos + 1 <= alpnEnd) {
int protocolLength = data[pos];
pos += 1;
if (pos + protocolLength > alpnEnd) return protocols;
String protocol = String.fromCharCodes(data.sublist(pos, pos + protocolLength));
protocols.add(protocol);
pos += protocolLength;
}
} else {
pos += extensionLength;
}
}
return protocols;
} catch (_) {
// Ignore errors, just return empty list
}
return null;
}
///判断是否是TLS Client Hello
static bool isTLSClientHello(Uint8List data) {
if (data.length < 43) return false;
@@ -70,7 +126,7 @@ class TLS {
}
}
} catch (_) {
// Ignore errors, just return null
// Ignore errors, just return null
}
return null;