diff --git a/android/app/src/main/kotlin/com/network/proxy/ProxyVpnService.kt b/android/app/src/main/kotlin/com/network/proxy/ProxyVpnService.kt index 67d814d..cad1246 100644 --- a/android/app/src/main/kotlin/com/network/proxy/ProxyVpnService.kt +++ b/android/app/src/main/kotlin/com/network/proxy/ProxyVpnService.kt @@ -12,6 +12,7 @@ import android.os.Build import android.os.ParcelFileDescriptor import android.util.Log import androidx.core.app.NotificationCompat +import com.network.proxy.vpn.ProxyVpnThread import com.network.proxy.vpn.socket.ProtectSocket import com.network.proxy.vpn.socket.ProtectSocketHolder @@ -21,6 +22,7 @@ import com.network.proxy.vpn.socket.ProtectSocketHolder */ class ProxyVpnService : VpnService(), ProtectSocket { private var vpnInterface: ParcelFileDescriptor? = null + private var vpnThread: ProxyVpnThread? = null companion object { const val MAX_PACKET_LEN = 1500 @@ -86,6 +88,7 @@ class ProxyVpnService : VpnService(), ProtectSocket { } private fun disconnect() { + vpnThread?.run { stopThread() } vpnInterface?.close() if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { stopForeground(STOP_FOREGROUND_REMOVE) @@ -111,6 +114,12 @@ class ProxyVpnService : VpnService(), ProtectSocket { ProtectSocketHolder.setProtectSocket(this) showServiceNotification() + vpnThread = ProxyVpnThread( + vpnInterface!!, + proxyHost, + proxyPort + ) + vpnThread!!.start() isRunning = true } diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/Connection.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/Connection.kt new file mode 100644 index 0000000..1211ccc --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/Connection.kt @@ -0,0 +1,179 @@ +package com.network.proxy.vpn + +import android.util.Log +import com.network.proxy.vpn.socket.CloseableConnection +import com.network.proxy.vpn.transport.protocol.IP4Header +import com.network.proxy.vpn.transport.protocol.TCPHeader +import com.network.proxy.vpn.transport.protocol.UDPHeader +import com.network.proxy.vpn.util.PacketUtil +import java.io.ByteArrayOutputStream +import java.io.IOException +import java.nio.ByteBuffer +import java.nio.channels.SelectionKey +import java.nio.channels.spi.AbstractSelectableChannel +import kotlin.concurrent.Volatile + +class Connection( + val protocol: Protocol, + val sourceIp: Int, val sourcePort: Int, + val destinationIp: Int, val destinationPort: Int, + private val connectionCloser: CloseableConnection +) { + + var channel: AbstractSelectableChannel? = null + var selectionKey: SelectionKey? = null + + //接收用于存储来自远程主机的数据的缓冲器 + private val receivingStream: ByteArrayOutputStream = ByteArrayOutputStream() + + //发送缓冲区,用于存储要从vpn客户端发送到目标主机的数据 + private val sendingStream: ByteArrayOutputStream = ByteArrayOutputStream() + + var hasReceivedLastSegment = false + + /** + * 是否初始化链接 针对代理判断协议延迟初始化 + */ + var isInitConnect = false + + //指示三向握手是否已完成 + var isConnected = false + + //从客户端接收的最后一个数据包 + var lastIpHeader: IP4Header? = null + var lastTcpHeader: TCPHeader? = null + var lastUdpHeader: UDPHeader? = null + + var timestampSender = 0 + var timestampReplyTo = 0 + + //从客户端接收的序列 + var recSequence: Long = 0 + + //在tcp选项内的SYN期间由客户端发送 + var maxSegmentSize = 0 + + //跟踪我们发送给客户端的ack,并等待客户端返回ack + var sendUnAck: Long = 0 + + //发送到客户端的下一个ack + var sendNext: Long = 0 + + //true when connection is about to be close + var isClosingConnection = false + + //指示客户端的数据已准备好发送到目标 + @Volatile + var isDataForSendingReady = false + + //closing session and aborting connection, will be done by background task + @Volatile + var isAbortingConnection = false + + //indicate that vpn client has sent FIN flag and it has been acked + var isAckedToFin = false + + companion object { + fun getConnectionKey( + protocol: Protocol, destIp: Int, destPort: Int, sourceIp: Int, sourcePort: Int + ): String { + return protocol.name + "|" + PacketUtil.intToIPAddress(sourceIp) + ":" + sourcePort + + "->" + PacketUtil.intToIPAddress(destIp) + ":" + destPort + } + } + +// fun getConnectionKey(): String { +// return getConnectionKey(protocol, destinationIp, destinationIp, sourceIp, sourcePort) +// } + + fun closeConnection() { + connectionCloser.closeConnection(this) + } + + /** + * 设置要发送到目标服务器的数据 + */ + @Synchronized + fun setSendingData(data: ByteBuffer): Int { + val remaining = data.remaining() + sendingStream.write(data.array(), data.position(), data.remaining()) + return remaining + } + + @Synchronized + fun addReceivedData(data: ByteArray?) { + try { + receivingStream.write(data) + } catch (e: IOException) { + Log.e(TAG, e.toString()) + } + } + + /** + * 获取缓冲区中接收到的所有数据并清空它。 + */ + @Synchronized + fun getReceivedData(maxSize: Int): ByteArray? { + var data = receivingStream.toByteArray() + receivingStream.reset() + if (data.size > maxSize) { + val small = ByteArray(maxSize) + System.arraycopy(data, 0, small, 0, maxSize) + val len = data.size - maxSize + receivingStream.write(data, maxSize, len) + data = small + } + return data + } + + /** + * buffer has more data for vpn client + */ + fun hasReceivedData(): Boolean { + return receivingStream.size() > 0 + } + + fun hasDataToSend(): Boolean { + return sendingStream.size() > 0 + } + + /** + * 出列数据以发送到服务器 + */ + @Synchronized + fun getSendingData(): ByteArray? { + val data = sendingStream.toByteArray() + sendingStream.reset() + return data + } + + fun cancelKey() { + selectionKey?.let { + synchronized(it) { + if (!it.isValid) return + it.cancel() + } + } + + } + + fun subscribeKey(op: Int) { + selectionKey?.let { + synchronized(it) { + if (!it.isValid) return + it.interestOps(it.interestOps() or op) + } + } + } + + fun unsubscribeKey(op: Int) { + selectionKey?.let { + synchronized(it) { + if (!it.isValid) return + it.interestOps(it.interestOps() and op.inv()) + } + } + } + + +} \ No newline at end of file diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/ConnectionHandler.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/ConnectionHandler.kt new file mode 100644 index 0000000..9168662 --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/ConnectionHandler.kt @@ -0,0 +1,510 @@ +package com.network.proxy.vpn + +import android.os.Build +import android.util.Log +import com.network.proxy.vpn.Connection.Companion.getConnectionKey +import com.network.proxy.vpn.socket.ClientPacketWriter +import com.network.proxy.vpn.socket.SocketNIODataService +import com.network.proxy.vpn.transport.icmp.ICMPPacket +import com.network.proxy.vpn.transport.icmp.ICMPPacketFactory +import com.network.proxy.vpn.transport.protocol.IP4Header +import com.network.proxy.vpn.transport.protocol.IPPacketFactory +import com.network.proxy.vpn.transport.protocol.TCPHeader +import com.network.proxy.vpn.transport.protocol.TCPPacketFactory +import com.network.proxy.vpn.transport.protocol.UDPPacketFactory +import com.network.proxy.vpn.util.PacketUtil.getOutput +import com.network.proxy.vpn.util.PacketUtil.intToIPAddress +import com.network.proxy.vpn.util.PacketUtil.isPacketCorrupted +import com.network.proxy.vpn.util.TLS.getDomain +import com.network.proxy.vpn.util.TLS.isTLSClientHello +import java.io.IOException +import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.SocketAddress +import java.nio.ByteBuffer +import java.nio.channels.SelectionKey +import java.nio.channels.SocketChannel +import java.util.concurrent.ExecutorService +import java.util.concurrent.SynchronousQueue +import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.TimeUnit + +class ConnectionHandler( + private val manager: ConnectionManager, + private val nioService: SocketNIODataService, + private val writer: ClientPacketWriter +) { + + private val pingThreadPool: ExecutorService = ThreadPoolExecutor( + 1, 20, // 1 - 20 parallel pings max + 60L, TimeUnit.SECONDS, + SynchronousQueue(), + ThreadPoolExecutor.DiscardPolicy() // Replace running pings if there's too many + ) + + /** + * Handle unknown raw IP packet data + * + * @param stream ByteBuffer to be read + */ + @Throws(IOException::class) + fun handlePacket(stream: ByteBuffer) { + val rawPacket = ByteArray(stream.limit()) + stream[rawPacket, 0, stream.limit()] + stream.rewind() + + val ipHeader = IPPacketFactory.createIP4Header(stream) + + if (ipHeader == null) { + stream.rewind() + Log.w(TAG, "Malformed IP packet ") + return + } + if (ipHeader.protocol.toInt() == 6) { + handleTCPPacket(stream, ipHeader) + } else if (ipHeader.protocol.toInt() == 17) { + handleUDPPacket(stream, ipHeader) + } else if (ipHeader.protocol.toInt() == 1) { + handleICMPPacket(stream, ipHeader) + } else { + Log.w(TAG, "Unsupported IP protocol: " + ipHeader.protocol) + } + } + + @Throws(IOException::class) + private fun handleUDPPacket(clientPacketData: ByteBuffer, ipHeader: IP4Header) { + val udpHeader = UDPPacketFactory.createUDPHeader(clientPacketData) + var connection = manager.getConnection( + Protocol.UDP, + ipHeader.destinationIP, udpHeader.destinationPort, + ipHeader.sourceIP, udpHeader.sourcePort + ) + val newSession = connection == null + if (connection == null) { + connection = manager.createUDPConnection( + ipHeader.destinationIP, udpHeader.destinationPort, + ipHeader.sourceIP, udpHeader.sourcePort + ) + } + synchronized(connection) { + connection.lastIpHeader = ipHeader + connection.lastUdpHeader = udpHeader + manager.addClientData(clientPacketData, connection) + connection.isDataForSendingReady = true + + // We don't register the session until it's fully populated (as above) + if (newSession) nioService.registerSession(connection) + + // Ping the NIO thread to write this, when the session is next writable + connection.subscribeKey(SelectionKey.OP_WRITE) + nioService.refreshSelect(connection) + } + manager.keepSessionAlive(connection) + } + + /** + * 是否支持协议 + */ + private val methods: List = + mutableListOf("GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "TRACE", "CONNECT") + + private fun supperProtocol(packetData: ByteBuffer): Boolean { + val position = packetData.position() + //判断是否是ssl握手 + if (isTLSClientHello(packetData) && getDomain(packetData) != null) { + packetData.position(position) + return true + } + packetData.position(position) + for (method in methods) { + if (packetData.remaining() < method.length) { + continue + } + val bytes = ByteArray(method.length) + for (i in bytes.indices) { + bytes[i] = packetData[position + i] + } + if (method.equals(String(bytes), ignoreCase = true)) { + return true + } + } + return false + } + + /** + * 获取代理地址 + */ + private fun getProxyAddress( + packetData: ByteBuffer, + destinationIP: Int, + destinationPort: Int + ): SocketAddress { + val supperProtocol = supperProtocol(packetData) + var socketAddress: SocketAddress? = null + if (supperProtocol) { + socketAddress = manager.proxyAddress + } + if (socketAddress == null) { + val ips = intToIPAddress(destinationIP) + socketAddress = InetSocketAddress(ips, destinationPort) + } + return socketAddress + } + + @Throws(IOException::class) + private fun handleTCPPacket(clientPacketData: ByteBuffer, ip4Header: IP4Header) { + val tcpHeader = TCPPacketFactory.createTCPHeader(clientPacketData) + val dataLength = clientPacketData.limit() - clientPacketData.position() + val sourceIP = ip4Header.sourceIP + val destinationIP = ip4Header.destinationIP + val sourcePort = tcpHeader.getSourcePort() + val destinationPort = tcpHeader.getDestinationPort() + if (tcpHeader.isSYN()) { + // 3-way handshake + create new session + replySynAck(ip4Header, tcpHeader) + } else if (tcpHeader.isACK()) { + val key = + getConnectionKey(Protocol.TCP, destinationIP, destinationPort, sourceIP, sourcePort) + val connection = manager.getConnectionByKey(key) + if (connection == null) { + Log.w(TAG, "Ack for unknown session: $key") + if (tcpHeader.isFIN()) { + sendLastAck(ip4Header, tcpHeader) + } else if (!tcpHeader.isRST()) { + sendRstPacket(ip4Header, tcpHeader, dataLength) + } + return + } + synchronized(connection) { + connection.lastIpHeader = ip4Header + connection.lastTcpHeader = tcpHeader + + //any data from client? + if (dataLength > 0) { + if (!connection.isInitConnect) { + connection.isInitConnect = true + val proxyAddress = + getProxyAddress(clientPacketData, destinationIP, destinationPort) + try { + val channel = + connection.channel as SocketChannel? + val connected = channel!!.connect(proxyAddress) + connection.isConnected = connected + nioService.registerSession(connection) + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + Log.d( + TAG, + "Proxy Initiate connecting key:" + key + " " + channel.localAddress + " to remote tcp server: " + channel.remoteAddress + ) + } + } catch (e: Exception) { + val ips = intToIPAddress(destinationIP) + Log.w( + TAG, + "Failed to reconnect to $ips:$destinationPort", + e + ) + } + } + + //accumulate data from client + if (connection.recSequence == 0L || tcpHeader.sequenceNumber >= connection.recSequence) { + val addedLength = manager.addClientData(clientPacketData, connection) + //send ack to client only if new data was added + sendAck(ip4Header, tcpHeader, addedLength, connection) + } else { + sendAckForDisorder(ip4Header, tcpHeader, dataLength) + } + } else { + //an ack from client for previously sent data + acceptAck(tcpHeader, connection) + if (connection.isClosingConnection) { + sendFinAck(ip4Header, tcpHeader, connection) + } else if (connection.isAckedToFin && !tcpHeader.isFIN()) { + //the last ACK from client after FIN-ACK flag was sent + manager.closeConnection( + Protocol.TCP, + destinationIP, + destinationPort, + sourceIP, + sourcePort + ) + // Log.d(TAG, "got last ACK after FIN, session is now closed."); + } + } + //received the last segment of data from vpn client + if (tcpHeader.isPSH()) { + // Tell the NIO thread to immediately send data to the destination + pushDataToDestination(connection, tcpHeader) + } else if (tcpHeader.isFIN()) { + //fin from vpn client is the last packet + //ack it +// Log.d(TAG, "FIN from vpn client, will ack it."); + ackFinAck(ip4Header, tcpHeader, connection) + } else if (tcpHeader.isRST()) { + resetTCPConnection(ip4Header, tcpHeader) + } + if (!connection.isAbortingConnection) { + manager.keepSessionAlive(connection) + } + } + } else if (tcpHeader.isFIN()) { + //case client sent FIN without ACK + val connection = manager.getConnection( + Protocol.TCP, + destinationIP, + destinationPort, + sourceIP, + sourcePort + ) + if (connection == null) ackFinAck( + ip4Header, + tcpHeader, + null + ) else manager.keepSessionAlive(connection) + } else if (tcpHeader.isRST()) { + resetTCPConnection(ip4Header, tcpHeader) + } else { + Log.d(TAG, "unknown TCP flag") + val str1 = getOutput(ip4Header, tcpHeader, clientPacketData.array()) + Log.d(TAG, ">>>>>>>> Received from client <<<<<<<<<<") + Log.d(TAG, str1) + Log.d(TAG, ">>>>>>>>>>>>>>>>>>>end receiving from client>>>>>>>>>>>>>>>>>>>>>") + } + } + + private fun sendRstPacket(ip: IP4Header, tcp: TCPHeader, dataLength: Int) { + val data = TCPPacketFactory.createRstData(ip, tcp, dataLength) + writer.write(data) + Log.d( + TAG, "Sent RST Packet to client with dest => " + + intToIPAddress(ip.destinationIP) + ":" + + tcp.getDestinationPort() + ) + } + + private fun sendLastAck(ip: IP4Header, tcp: TCPHeader) { + val data = TCPPacketFactory.createResponseAckData(ip, tcp, tcp.sequenceNumber + 1) + writer.write(data) +// Log.d(TAG,"Sent last ACK Packet to client with dest => " + +// PacketUtil.intToIPAddress(ip.getDestinationIP()) + ":" + +// tcp.getDestinationPort()); + } + + private fun ackFinAck(ip: IP4Header, tcp: TCPHeader, connection: Connection?) { + val ack = tcp.sequenceNumber + 1 + val seq = tcp.ackNumber + val data = TCPPacketFactory.createFinAckData(ip, tcp, ack, seq, isFin = true, isAck = true) + writer.write(data) + if (connection != null) { + connection.cancelKey() + manager.closeConnection(connection) + // Log.d(TAG,"ACK to client's FIN and close session => "+PacketUtil.intToIPAddress(ip.getDestinationIP())+":"+tcp.getDestinationPort() +// +"-"+PacketUtil.intToIPAddress(ip.getSourceIP())+":"+tcp.getSourcePort()); + } + } + + private fun sendFinAck(ip: IP4Header, tcp: TCPHeader, connection: Connection) { + val ack = tcp.sequenceNumber + val seq = tcp.ackNumber + val data = TCPPacketFactory.createFinAckData(ip, tcp, ack, seq, isFin = true, isAck = false) + val stream = ByteBuffer.wrap(data) + writer.write(data) + Log.d(TAG, "00000000000 FIN-ACK packet data to vpn client 000000000000") + var vpnIp: IP4Header? = null + try { + vpnIp = IPPacketFactory.createIP4Header(stream) + } catch (e: Exception) { + e.printStackTrace() + } + var vpnTcp: TCPHeader? = null + try { + if (vpnIp != null) vpnTcp = TCPPacketFactory.createTCPHeader(stream) + } catch (e: Exception) { + e.printStackTrace() + } + if (vpnIp != null && vpnTcp != null) { + val logOut = getOutput(vpnIp, vpnTcp, data) + Log.d(TAG, logOut) + } + Log.d(TAG, "0000000000000 finished sending FIN-ACK packet to vpn client 000000000000") + connection.sendNext = seq + 1 + //avoid re-sending it, from here client should take care the rest + connection.isClosingConnection = false + } + + private fun pushDataToDestination(connection: Connection, tcp: TCPHeader) { + connection.isDataForSendingReady = true + connection.timestampReplyTo = tcp.timeStampSender + connection.timestampSender = System.currentTimeMillis().toInt() + + // Ping the NIO thread to write this, when the session is next writable + connection.subscribeKey(SelectionKey.OP_WRITE) + nioService.refreshSelect(connection) + } + + /** + * send acknowledgment packet to VPN client + * + * @param acceptedDataLength Data Length + */ + private fun sendAck( + ipHeader: IP4Header, tcpHeader: TCPHeader, acceptedDataLength: Int, connection: Connection + ) { + val ackNumber = connection.recSequence + acceptedDataLength + connection.recSequence = ackNumber + val ackData = TCPPacketFactory.createResponseAckData(ipHeader, tcpHeader, ackNumber) + writer.write(ackData) + } + + /** + * resend the last acknowledgment packet to VPN client, e.g. when an unexpected out of order + * packet arrives. + */ + private fun resendAck(connection: Connection) { + val data = TCPPacketFactory.createResponseAckData( + connection.lastIpHeader!!, + connection.lastTcpHeader!!, + connection.recSequence + ) + writer.write(data) + } + + private fun sendAckForDisorder( + ipHeader: IP4Header, tcpHeader: TCPHeader, acceptedDataLength: Int + ) { + val ackNumber = tcpHeader.sequenceNumber + acceptedDataLength + Log.e( + TAG, "sent disorder ack, ack# " + tcpHeader.sequenceNumber + + " + " + acceptedDataLength + " = " + ackNumber + ) + val data = TCPPacketFactory.createResponseAckData(ipHeader, tcpHeader, ackNumber) + writer.write(data) + } + + /** + * acknowledge a packet. + * + * @param tcpHeader TCP Header + */ + private fun acceptAck(tcpHeader: TCPHeader, connection: Connection) { + val isCorrupted = isPacketCorrupted(tcpHeader) + +// connection.setPacketCorrupted(isCorrupted); + if (isCorrupted) { + Log.e(TAG, "prev packet was corrupted, last ack# " + tcpHeader.ackNumber) + } + if (tcpHeader.ackNumber > connection.sendUnAck || + tcpHeader.ackNumber == connection.sendNext + ) { +// connection.setAcked(true); + connection.sendUnAck = tcpHeader.ackNumber + connection.recSequence = tcpHeader.sequenceNumber + connection.timestampReplyTo = tcpHeader.timeStampSender + connection.timestampSender = System.currentTimeMillis().toInt() + } else { + Log.d( + TAG, + "Not Accepting ack# " + tcpHeader.ackNumber + " , it should be: " + connection.sendNext + ) + Log.d(TAG, "Prev sendUnAck: " + connection.sendUnAck) + // connection.setAcked(false); + } + } + + /** + * set connection as aborting so that background worker will close it. + * + * @param ip IP + * @param tcp TCP + */ + private fun resetTCPConnection(ip: IP4Header, tcp: TCPHeader) { + val session = manager.getConnection( + Protocol.TCP, + ip.destinationIP, tcp.getDestinationPort(), + ip.sourceIP, tcp.getSourcePort() + ) + if (session != null) { + synchronized(session) { session.isAbortingConnection = true } + } + } + + /** + * create a new client's session and SYN-ACK packet data to respond to client + */ + @Throws(IOException::class) + private fun replySynAck(ipHeader: IP4Header, tcpHeader: TCPHeader) { + ipHeader.identification = 0 + val packet = TCPPacketFactory.createSynAckPacketData(ipHeader, tcpHeader) + val tcpTransport = packet.transportHeader as TCPHeader + val connection = manager.createTCPConnection( + ipHeader.destinationIP, tcpHeader.getDestinationPort(), + ipHeader.sourceIP, tcpHeader.getSourcePort() + ) + if (connection.lastIpHeader != null) { + // We have an existing session for this connection! We've somehow received a SYN + // for an existing socket (or some kind of other race). We resend the last ACK + // for this session, rejecting this SYN. Not clear why this happens, but it can. + resendAck(connection) + return + } + synchronized(connection) { + connection.maxSegmentSize = tcpTransport.maxSegmentSize.toInt() + connection.sendUnAck = tcpTransport.sequenceNumber + connection.sendNext = tcpTransport.sequenceNumber + 1 + //client initial sequence has been incremented by 1 and set to ack + connection.recSequence = tcpTransport.ackNumber + connection.lastIpHeader = ipHeader + connection.lastTcpHeader = tcpHeader + if (connection.isInitConnect) { + nioService.registerSession(connection) + } + writer.write(packet.buffer) + } + } + + private fun handleICMPPacket(clientPacketData: ByteBuffer, ipHeader: IP4Header) { + val requestPacket = ICMPPacketFactory.parseICMPPacket(clientPacketData) + Log.d(TAG, "Got an ICMP ping packet, type $requestPacket") + if (requestPacket.type == ICMPPacket.DESTINATION_UNREACHABLE_TYPE) { + // This is a packet from the phone, telling somebody that a destination is unreachable. + // Might be caused by issues on our end, but it's unclear what kind of issues. Regardless, + // we can't send ICMP messages ourselves or react usefully, so we drop these silently. + return + } else require(requestPacket.type == ICMPPacket.ECHO_REQUEST_TYPE) { + // We only actually support outgoing ping packets. Loudly drop anything else: + "Unknown ICMP type (" + requestPacket.type + "). Only echo requests are supported" + } + pingThreadPool.execute(object : Runnable { + override fun run() { + try { + if (!isReachable(intToIPAddress(ipHeader.destinationIP))) { + Log.d(TAG, "Failed ping, ignoring") + return + } + val response = ICMPPacketFactory.buildSuccessPacket(requestPacket) + + // Flip the address + val destination = ipHeader.destinationIP + val source = ipHeader.sourceIP + ipHeader.sourceIP = destination + ipHeader.destinationIP = source + val responseData = ICMPPacketFactory.packetToBuffer(ipHeader, response) + Log.d(TAG, "Successful ping response") + writer.write(responseData) + } catch (e: Exception) { + Log.w(TAG, "Handling ICMP failed with " + e.message) + return + } + } + + private fun isReachable(ipAddress: String): Boolean { + return try { + InetAddress.getByName(ipAddress).isReachable(10000) + } catch (e: IOException) { + false + } + } + }) + } +} \ No newline at end of file diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/ConnectionManager.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/ConnectionManager.kt new file mode 100644 index 0000000..9b7f49f --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/ConnectionManager.kt @@ -0,0 +1,159 @@ +package com.network.proxy.vpn + +import android.os.Build +import android.util.Log +import com.network.proxy.vpn.socket.CloseableConnection +import com.network.proxy.vpn.socket.Constant +import com.network.proxy.vpn.socket.ProtectSocketHolder.Companion.protect +import com.network.proxy.vpn.util.PacketUtil +import java.io.IOException +import java.net.InetSocketAddress +import java.net.SocketAddress +import java.nio.ByteBuffer +import java.nio.channels.DatagramChannel +import java.nio.channels.SocketChannel +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ConcurrentMap + +/** + * 管理VPN客户端的连接 + */ +class ConnectionManager : CloseableConnection { + private val table: ConcurrentMap = ConcurrentHashMap() + var proxyAddress: InetSocketAddress? = null + + var DEFAULT_PORTS: List = listOf( + 80, // HTTP + 443, // HTTPS + 8080, // Common local dev ports + 8000, 8080, 8888, 9000 // Common local dev ports + ) + + override fun closeConnection(connection: Connection) { + closeConnection( + connection.protocol, connection.destinationIp, connection.destinationPort, + connection.sourceIp, connection.sourcePort + ) + } + + /** + * 从内存中删除连接,然后关闭套接字。 + * + */ + fun closeConnection(protocol: Protocol, ip: Int, port: Int, srcIp: Int, srcPort: Int) { + val key = Connection.getConnectionKey(protocol, ip, port, srcIp, srcPort) + val session: Connection? = table.remove(key) + session?.let { + val channel = session.channel + try { + channel?.close() + } catch (e: IOException) { + e.printStackTrace() + } + } + } + + fun getConnection( + protocol: Protocol, ip: Int, port: Int, srcIp: Int, srcPort: Int + ): Connection? { + val key = Connection.getConnectionKey(protocol, ip, port, srcIp, srcPort) + return getConnectionByKey(key) + } + + fun getConnectionByKey(key: String?): Connection? { + return table[key] + } + + /** + * 创建tcp连接 + */ + fun createTCPConnection(ip: Int, port: Int, srcIp: Int, srcPort: Int): Connection { + val key = Connection.getConnectionKey(Protocol.TCP, ip, port, srcIp, srcPort) + val existingConnection: Connection? = table[key] + if (existingConnection != null) { + return existingConnection + } + + val connection = Connection(Protocol.TCP, srcIp, srcPort, ip, port, this) + + val channel: SocketChannel = SocketChannel.open() + channel.socket().keepAlive = true + channel.socket().tcpNoDelay = true + channel.socket().soTimeout = 0 + channel.socket().receiveBufferSize = Constant.MAX_RECEIVE_BUFFER_SIZE + channel.configureBlocking(false) + + Log.d(TAG, "created new SocketChannel for $key") + + protect(channel.socket()) + + connection.channel = channel + + val socketAddress: SocketAddress? = null +// if (!DEFAULT_PORTS.contains(port)) { +// socketAddress = new InetSocketAddress(ips, port); +// } + + // if (!DEFAULT_PORTS.contains(port)) { +// socketAddress = new InetSocketAddress(ips, port); +// } + connection.isInitConnect = socketAddress != null + + if (socketAddress != null) { + val connected = channel.connect(socketAddress) + connection.isConnected = connected + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + Log.d( + TAG, + "Initiate connecting " + channel.localAddress + " to remote tcp server: " + channel.remoteAddress + ) + } + } + + table[key] = connection + return connection + } + + + @Throws(IOException::class) + fun createUDPConnection(ip: Int, port: Int, srcIp: Int, srcPort: Int): Connection { + val keys = Connection.getConnectionKey(Protocol.UDP, ip, port, srcIp, srcPort) + + val existingConnection: Connection? = table[keys] + if (existingConnection != null) return existingConnection + + val connection = Connection(Protocol.UDP, srcIp, srcPort, ip, port, this) + val channel: DatagramChannel = DatagramChannel.open() + channel.socket().soTimeout = 0 + channel.configureBlocking(false) + protect(channel.socket()) + connection.channel = channel + + // Initiate connection early to reduce latency + val ips = PacketUtil.intToIPAddress(ip) + val socketAddress: SocketAddress = InetSocketAddress(ips, port) + channel.connect(socketAddress) + connection.isConnected = channel.isConnected + table[keys] = connection + + return connection + } + + /** + * 添加来自客户端的数据,该数据稍后将在接收到PSH标志时发送到目的服务器。 + */ + fun addClientData(buffer: ByteBuffer, session: Connection): Int { + return if (buffer.limit() <= buffer.position()) 0 else session.setSendingData(buffer) + } + + /** + * 阻止java垃圾收集器收集会话 + */ + fun keepSessionAlive(connection: Connection) { + val key = Connection.getConnectionKey( + connection.protocol, connection.destinationIp, connection.destinationPort, + connection.sourceIp, connection.sourcePort + ) + table[key] = connection + } +} \ No newline at end of file diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/Protocol.java b/android/app/src/main/kotlin/com/network/proxy/vpn/Protocol.java new file mode 100644 index 0000000..2bd082d --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/Protocol.java @@ -0,0 +1,6 @@ +package com.network.proxy.vpn; + +public enum Protocol { + TCP, + UDP +} diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/ProxyVpnThread.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/ProxyVpnThread.kt new file mode 100644 index 0000000..d16ae52 --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/ProxyVpnThread.kt @@ -0,0 +1,105 @@ +package com.network.proxy.vpn + +import android.os.ParcelFileDescriptor +import android.util.Log +import com.network.proxy.ProxyVpnService.Companion.MAX_PACKET_LEN +import com.network.proxy.vpn.socket.ClientPacketWriter +import com.network.proxy.vpn.socket.SocketNIODataService +import java.io.FileInputStream +import java.io.FileOutputStream +import java.io.InterruptedIOException +import java.net.InetSocketAddress +import java.nio.ByteBuffer + + +/** + * VPN线程,负责处理VPN接收到的数据包 + */ +class ProxyVpnThread( + vpnInterface: ParcelFileDescriptor, + proxyHost: String, + proxyPort: Int, +) : Thread("Vpn thread") { + companion object { + const val TAG = "ProxyVpnThread" + } + + @Volatile + private var running = false + + private val vpnReadChannel = FileInputStream(vpnInterface.fileDescriptor).channel + + // 此VPN接收的来自上游服务器的数据包 + private val vpnWriteStream = FileOutputStream(vpnInterface.fileDescriptor) + private val vpnPacketWriter = ClientPacketWriter(vpnWriteStream) + private val vpnPacketWriterThread = Thread(vpnPacketWriter) + + // Background service & task for non-blocking socket + private val nioService = SocketNIODataService(vpnPacketWriter) + private val dataServiceThread = Thread(nioService, "Socket NIO thread") + + private val manager = ConnectionManager().apply { + //流量转发到代理地址 + this.proxyAddress = InetSocketAddress(proxyHost, proxyPort) + } + + private val handler = ConnectionHandler(manager, nioService, vpnPacketWriter) + + private var currentThread: Thread? = null + + override fun run() { + Log.i(TAG, "Vpn thread starting") + currentThread = currentThread() + dataServiceThread.start() + vpnPacketWriterThread.start() + + val readBuffer = ByteBuffer.allocate(MAX_PACKET_LEN) + running = true + while (running) { + try { + val length = vpnReadChannel.read(readBuffer) + + if (length > 0) { + try { + readBuffer.flip() + val byteArray = ByteArray(length) + readBuffer.get(byteArray) + + val packet = ByteBuffer.wrap(byteArray) + handler.handlePacket(packet) + } catch (e: Exception) { + val errorMessage = (e.message ?: e.toString()) + Log.e(TAG, errorMessage, e) + } + + readBuffer.clear() + } else { + sleep(50) + } + } catch (e: InterruptedException) { + Log.i(TAG, "Sleep interrupted: " + e.message) + } catch (e: InterruptedIOException) { + Log.i(TAG, "Read interrupted: " + e.message) + } catch (e: Exception) { + val errorMessage = (e.message ?: e.toString()) + Log.e(TAG, errorMessage, e) + } + } + + Log.i(TAG, "Vpn thread stop") + } + + @Synchronized + fun stopThread() { + if (running) { + running = false + nioService.shutdown() + dataServiceThread.interrupt() + + vpnPacketWriter.shutdown() + vpnPacketWriterThread.interrupt() + currentThread?.interrupt() + } + } + +} diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/socket/ClientPacketWriter.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/ClientPacketWriter.kt new file mode 100644 index 0000000..811b25c --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/ClientPacketWriter.kt @@ -0,0 +1,46 @@ +package com.network.proxy.vpn.socket + +import android.util.Log +import java.io.FileOutputStream +import java.io.IOException +import java.util.concurrent.BlockingDeque +import java.util.concurrent.LinkedBlockingDeque +import kotlin.concurrent.Volatile + +class ClientPacketWriter(private val clientWriter: FileOutputStream) : Runnable { + companion object { + private const val TAG: String = "ClientPacketWriter" + private const val MAX_PACKET_LEN = 32767 + } + + @Volatile + private var shutdown = false + + private val packetQueue: BlockingDeque = LinkedBlockingDeque() + + fun write(data: ByteArray) { + if (data.size > MAX_PACKET_LEN) throw Error("Packet too large") + packetQueue.addLast(data) + } + + fun shutdown() { + this.shutdown = true + } + + override fun run() { + while (!this.shutdown) { + try { + val data: ByteArray = this.packetQueue.take() + try { + this.clientWriter.write(data) + } catch (e: IOException) { + Log.e(TAG, "Error writing $shutdown data.length bytes to the VPN") + e.printStackTrace() + this.packetQueue.addFirst(data) // Put the data back, so it's resent + Thread.sleep(10) // Add an arbitrary tiny pause, in case that helps + } + } catch (ignored: InterruptedException) { + } + } + } +} diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/socket/CloseableConnection.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/CloseableConnection.kt new file mode 100644 index 0000000..1b13a5f --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/CloseableConnection.kt @@ -0,0 +1,10 @@ +package com.network.proxy.vpn.socket + +import com.network.proxy.vpn.Connection + +interface CloseableConnection { + /** + * 关闭连接 + */ + fun closeConnection(session: Connection) +} \ No newline at end of file diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/socket/Constant.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/Constant.kt new file mode 100644 index 0000000..c84c234 --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/Constant.kt @@ -0,0 +1,5 @@ +package com.network.proxy.vpn.socket + +object Constant { + const val MAX_RECEIVE_BUFFER_SIZE = 65535 +} \ No newline at end of file diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/socket/SocketChannelReader.java b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/SocketChannelReader.java new file mode 100644 index 0000000..0b12390 --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/SocketChannelReader.java @@ -0,0 +1,211 @@ +package com.network.proxy.vpn.socket; + +import androidx.annotation.NonNull; + +import android.util.Log; + +import com.network.proxy.vpn.Connection; +import com.network.proxy.vpn.TagKt; +import com.network.proxy.vpn.transport.protocol.IP4Header; +import com.network.proxy.vpn.transport.protocol.TCPHeader; +import com.network.proxy.vpn.transport.protocol.TCPPacketFactory; +import com.network.proxy.vpn.transport.protocol.UDPPacketFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedByInterruptException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.DatagramChannel; +import java.nio.channels.NotYetConnectedException; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.AbstractSelectableChannel; + + +/** + * Takes a session, and reads all available upstream data back into it. + * Used by the NIO thread, and run synchronously as part of that non-blocking loop. + */ +class SocketChannelReader { + + private final String TAG = TagKt.getTAG(this); + + private final ClientPacketWriter writer; + + public SocketChannelReader(ClientPacketWriter writer) { + this.writer = writer; + } + + public void read(Connection connection) { + AbstractSelectableChannel channel = connection.getChannel(); + + if (channel instanceof SocketChannel) { + readTCP(connection); + } else if (channel instanceof DatagramChannel) { + readUDP(connection); + } else { + return; + } + + // Resubscribe to reads, so that we're triggered again if more data arrives later. + connection.subscribeKey(SelectionKey.OP_READ); + + if (connection.isAbortingConnection()) { + Log.d(TAG, "removing aborted connection -> " + connection); + connection.cancelKey(); + if (channel instanceof SocketChannel) { + try { + SocketChannel socketChannel = (SocketChannel) channel; + if (socketChannel.isConnected()) { + socketChannel.close(); + } + } catch (IOException e) { + Log.e(TAG, e.toString()); + } + } else { + try { + DatagramChannel datagramChannel = (DatagramChannel) channel; + if (datagramChannel.isConnected()) { + datagramChannel.close(); + } + } catch (IOException e) { + e.printStackTrace(); + } + } + connection.closeConnection(); + } + } + + private void readTCP(@NonNull Connection connection) { + if (connection.isAbortingConnection()) { + return; + } + + SocketChannel channel = (SocketChannel) connection.getChannel(); + ByteBuffer buffer = ByteBuffer.allocate(Constant.MAX_RECEIVE_BUFFER_SIZE); + int len; + + try { + do { + len = channel.read(buffer); + if (len > 0) { //-1 mean it reach the end of stream + sendToRequester(buffer, len, connection); + buffer.clear(); + } else if (len == -1) { +// Log.d(TAG,"End of data from remote server, will send FIN to client"); + Log.d(TAG, "send FIN to: " + connection); + sendFin(connection); + connection.setAbortingConnection(true); + } + } while (len > 0); + } catch (NotYetConnectedException e) { + Log.e(TAG, "socket not connected"); + } catch (ClosedByInterruptException e) { + Log.e(TAG, "ClosedByInterruptException reading SocketChannel: " + e.getMessage()); + } catch (ClosedChannelException e) { + Log.e(TAG, "ClosedChannelException reading SocketChannel: " + e.getMessage()); + } catch (IOException e) { + Log.e(TAG, "Error reading data from SocketChannel: " + e.getMessage()); + connection.setAbortingConnection(true); + } + } + + private void sendToRequester(ByteBuffer buffer, int dataSize, @NonNull Connection connection) { + // Last piece of data is usually smaller than MAX_RECEIVE_BUFFER_SIZE. We use this as a + // trigger to set PSH on the resulting TCP packet that goes to the VPN. + connection.setHasReceivedLastSegment(dataSize < Constant.MAX_RECEIVE_BUFFER_SIZE); + + buffer.limit(dataSize); + buffer.flip(); + // TODO should allocate new byte array? + byte[] data = new byte[dataSize]; + System.arraycopy(buffer.array(), 0, data, 0, dataSize); + connection.addReceivedData(data); + //pushing all data to vpn client + while (connection.hasReceivedData()) { + pushDataToClient(connection); + } + } + + /** + * create packet data and send it to VPN client + */ + private void pushDataToClient(@NonNull Connection connection) { + if (!connection.hasReceivedData()) { + //no data to send + Log.d(TAG, "no data for vpn client"); + } + + IP4Header ipHeader = connection.getLastIpHeader(); + TCPHeader tcpheader = connection.getLastTcpHeader(); + // TODO What does 60 mean? + int max = connection.getMaxSegmentSize() - 60; + + if (max < 1) { + max = 1024; + } + + byte[] packetBody = connection.getReceivedData(max); + if (packetBody != null && packetBody.length > 0) { + long unAck = connection.getSendNext(); + long nextUnAck = connection.getSendNext() + packetBody.length; + connection.setSendNext((int) nextUnAck); + //we need this data later on for retransmission +// connection.setUnackData(packetBody); +// connection.setResendPacketCounter(0); + + byte[] data = TCPPacketFactory.createResponsePacketData(ipHeader, + tcpheader, packetBody, connection.getHasReceivedLastSegment(), + connection.getRecSequence(), (int) unAck, + connection.getTimestampSender(), connection.getTimestampReplyTo()); + + writer.write(data); + } + } + + private void sendFin(Connection connection) { + final IP4Header ipHeader = connection.getLastIpHeader(); + final TCPHeader tcpheader = connection.getLastTcpHeader(); + final byte[] data = TCPPacketFactory.INSTANCE.createFinData(ipHeader, tcpheader, + connection.getRecSequence(), connection.getSendNext(), + connection.getTimestampSender(), connection.getTimestampReplyTo()); + + writer.write(data); + } + + private void readUDP(Connection connection) { + DatagramChannel channel = (DatagramChannel) connection.getChannel(); + ByteBuffer buffer = ByteBuffer.allocate(Constant.MAX_RECEIVE_BUFFER_SIZE); + int len; + + try { + do { + if (connection.isAbortingConnection()) { + break; + } + + len = channel.read(buffer); + if (len > 0) { + buffer.limit(len); + buffer.flip(); + + //create UDP packet + byte[] data = new byte[len]; + System.arraycopy(buffer.array(), 0, data, 0, len); + byte[] packetData = UDPPacketFactory.createResponsePacket( + connection.getLastIpHeader(), connection.getLastUdpHeader(), data); + + //write to client + writer.write(packetData); + + buffer.clear(); + } + } while (len > 0); + } catch (NotYetConnectedException ex) { + Log.e(TAG, "failed to read from unconnected UDP socket"); + } catch (IOException e) { + Log.e(TAG, "Failed to read from UDP socket, aborting connection"); + connection.setAbortingConnection(true); + } + } +} diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/socket/SocketChannelWriter.java b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/SocketChannelWriter.java new file mode 100644 index 0000000..7184075 --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/SocketChannelWriter.java @@ -0,0 +1,148 @@ +package com.network.proxy.vpn.socket; + +import androidx.annotation.NonNull; +import android.util.Log; + + +import com.network.proxy.vpn.Connection; +import com.network.proxy.vpn.TagKt; +import com.network.proxy.vpn.transport.protocol.TCPPacketFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.DatagramChannel; +import java.nio.channels.NotYetConnectedException; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.AbstractSelectableChannel; + + +/** + * Takes a VPN session, and writes all received data from it to the upstream channel. + *

+ * If any writes fail, it resubscribes to OP_WRITE, and tries again next time + * that fires (as soon as the channel is ready for more data). + *

+ * Used by the NIO thread, and run synchronously as part of that non-blocking loop. + */ +public class SocketChannelWriter { + private final String TAG = TagKt.getTAG(this); + + private final ClientPacketWriter writer; + + SocketChannelWriter(ClientPacketWriter writer) { + this.writer = writer; + } + + public void write(@NonNull Connection connection) { + AbstractSelectableChannel channel = connection.getChannel(); + if (channel instanceof SocketChannel) { + writeTCP(connection); + } else if(channel instanceof DatagramChannel) { + writeUDP(connection); + } else { + // We only ever create TCP & UDP channels, so this should never happen + throw new IllegalArgumentException("Unexpected channel type: " + channel); + } + + if (connection.isAbortingConnection()) { + Log.d(TAG,"removing aborted connection -> " + connection); + connection.cancelKey(); + + if (channel instanceof SocketChannel) { + try { + SocketChannel socketChannel = (SocketChannel) channel; + if (socketChannel.isConnected()) { + socketChannel.close(); + } + } catch (IOException e) { + e.printStackTrace(); + } + } else { + try { + DatagramChannel datagramChannel = (DatagramChannel) channel; + if (datagramChannel.isConnected()) { + datagramChannel.close(); + } + } catch (IOException e) { + e.printStackTrace(); + } + } + + connection.closeConnection(); + } + } + + private void writeUDP(Connection connection) { + try { + writePendingData(connection); +// Date dt = new Date(); +// connection.connectionStartTime = dt.getTime(); + }catch(NotYetConnectedException ex2){ + connection.setAbortingConnection(true); + Log.e(TAG,"Error writing to unconnected-UDP server, will abort current connection: "+ex2.getMessage()); + } catch (IOException e) { + connection.setAbortingConnection(true); + e.printStackTrace(); + Log.e(TAG,"Error writing to UDP server, will abort connection: "+e.getMessage()); + } + } + + private void writeTCP(Connection connection) { + try { + writePendingData(connection); + } catch (NotYetConnectedException ex) { + Log.e(TAG,"failed to write to unconnected socket: " + ex.getMessage()); + } catch (IOException e) { + Log.e(TAG,"Error writing to server: " + e); + + //close connection with vpn client + byte[] rstData = TCPPacketFactory.INSTANCE.createRstData( + connection.getLastIpHeader(), connection.getLastTcpHeader(), 0); + + writer.write(rstData); + + //remove session + Log.e(TAG,"failed to write to remote socket, aborting connection"); + connection.setAbortingConnection(true); + } + } + + private void writePendingData(Connection connection) throws IOException { + if (!connection.hasDataToSend()) return; + AbstractSelectableChannel channel = connection.getChannel(); + + byte[] data = connection.getSendingData(); + ByteBuffer buffer = ByteBuffer.allocate(data.length); + buffer.put(data); + buffer.flip(); + + while (buffer.hasRemaining()) { + int bytesWritten = channel instanceof SocketChannel + ? ((SocketChannel) channel).write(buffer) + : ((DatagramChannel) channel).write(buffer); + + if (bytesWritten == 0) { + break; + } + } + + if (buffer.hasRemaining()) { + // The channel's own buffer is full, so we have to save this for later. + Log.i(TAG, buffer.remaining() + " bytes unwritten for " + channel); + + // Put the remaining data from the buffer back into the session + connection.setSendingData(buffer.compact()); + + // Subscribe to WRITE events, so we know when this is ready to resume. + connection.subscribeKey(SelectionKey.OP_WRITE); + } else { + // All done, all good -> wait until the next TCP PSH / UDP packet + connection.setDataForSendingReady(false); + + // We don't need to know about WRITE events any more, we've written all our data. + // This is safe from races with new data, due to the session lock in NIO. + connection.unsubscribeKey(SelectionKey.OP_WRITE); + } + } +} diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/socket/SocketNIODataService.java b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/SocketNIODataService.java new file mode 100644 index 0000000..4b94d20 --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/socket/SocketNIODataService.java @@ -0,0 +1,247 @@ +package com.network.proxy.vpn.socket; + +import android.util.Log; + + +import com.network.proxy.vpn.Connection; +import com.network.proxy.vpn.TagKt; + +import java.io.IOException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.DatagramChannel; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.AbstractSelectableChannel; +import java.util.Iterator; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + + +/** + * A service that single-threadedly processes the events around our session connections, + * entirely via non-blocking NIO. + *

+ * It uses a Selector that fires on outgoing socket events (connected, readable, writable), + * handles the resulting operations, and keeps those subscriptions up to date. + */ +public class SocketNIODataService implements Runnable { + + private final String TAG = TagKt.getTAG(this); + private final ReentrantLock nioSelectionLock = new ReentrantLock(); + private final ReentrantLock nioHandlingLock = new ReentrantLock(); + private final Selector selector = Selector.open(); + + private final SocketChannelReader reader; + private final SocketChannelWriter writer; + + private volatile boolean shutdown = false; + + + public SocketNIODataService(ClientPacketWriter clientPacketWriter) throws IOException { + reader = new SocketChannelReader(clientPacketWriter); + writer = new SocketChannelWriter(clientPacketWriter); + } + + @Override + public void run() { + Log.d(TAG,"SocketNIODataService starting in background..."); + runTask(); + } + + public void registerSession(Connection connection) throws ClosedChannelException { + AbstractSelectableChannel channel = connection.getChannel(); + + boolean isConnected = channel instanceof DatagramChannel + ? ((DatagramChannel) channel).isConnected() + : ((SocketChannel) channel).isConnected(); + +// Log.i(TAG, "Registering new session: " + session); + + Lock selectorLock = lockSelector(selector); + try { + SelectionKey selectionKey = channel.register(selector, + isConnected + ? SelectionKey.OP_READ + : SelectionKey.OP_CONNECT + ); + connection.setSelectionKey(selectionKey); + selectionKey.attach(connection); +// Log.d(TAG, "Registered selector successfully"); + } finally { + selectorLock.unlock(); + } + } + + private Lock lockSelector(Selector selector) { + boolean gotSelectionLock = nioSelectionLock.tryLock(); + if (gotSelectionLock) return nioSelectionLock; + + nioHandlingLock.lock(); // Ensure the NIO thread can't do anything on wakeup + selector.wakeup(); + + nioSelectionLock.lock(); // Actually get the lock we want + nioHandlingLock.unlock(); // Release the handling lock, which we no longer care about + + return nioSelectionLock; + } + + /** + * If the selector is currently select()ing, wake it up (e.g. to register changes to + * interestOps). If it's not (and so it probably will select() very soon anyway) do nothing. + * This is designed to be run after changing readyOps, to ensure the new ops get monitored + * immediately (and fire immediately, if already ready). Without this, that blocks. + */ + public void refreshSelect(Connection connection) { + boolean gotLock = nioSelectionLock.tryLock(); + + if (!gotLock) { + connection.getSelectionKey().selector().wakeup(); + } else { + nioSelectionLock.unlock(); + } + } + + /** + * Shut down the NIO thread + */ + public void shutdown(){ + this.shutdown = true; + selector.wakeup(); + } + + private void runTask(){ + Log.i(TAG, "NIO selector is running..."); + + while(!shutdown){ + try { + nioSelectionLock.lockInterruptibly(); + selector.select(); + } catch (IOException e) { + Log.e(TAG,"Error in Selector.select(): " + e.getMessage()); + try { + Thread.sleep(100); + } catch (InterruptedException ex) { + Log.e(TAG, e.toString()); + } + continue; + } catch (InterruptedException ex) { + Log.i(TAG, "Select() interrupted"); + } finally { + if (nioSelectionLock.isHeldByCurrentThread()) { + nioSelectionLock.unlock(); + } + } + + if (shutdown) { + break; + } + + // A lock here makes it possible to reliably grab the selection lock above + nioHandlingLock.lock(); + try { + Iterator iterator = selector.selectedKeys().iterator(); + + while (iterator.hasNext()) { + SelectionKey key = iterator.next(); + Connection connection = ((Connection) key.attachment()); + synchronized (connection) { // Sessions are locked during processing (no VPN data races) + try { + processSelectionKey(key); + } catch (IOException e) { + synchronized (key) { + key.cancel(); + } + } + } + + iterator.remove(); + if (shutdown) { + break; + } + } + } finally { + nioHandlingLock.unlock(); + } + } + Log.i(TAG, "NIO selector shutdown"); + } + + private void processSelectionKey(SelectionKey key) throws IOException { + if (!key.isValid()) { + Log.d(TAG,"Invalid SelectionKey"); + return; + } + + SelectableChannel channel = key.channel(); + + Connection connection = ((Connection) key.attachment()); + if (connection == null) { + Log.w(TAG, "Key fired with no session attached"); + return; + } + + if (channel instanceof SocketChannel && !connection.isConnected() && key.isConnectable()) { + SocketChannel socketChannel = (SocketChannel) channel; + + if (socketChannel.isConnectionPending()) { + boolean connected = socketChannel.finishConnect(); + connection.setConnected(connected); + } else { + throw new IllegalStateException("TCP channels must either be connected or pending connection"); + } + } + + if (isConnected(channel)) { + processConnectedSelection(key, connection); + } + } + + private boolean isConnected(SelectableChannel channel) { + if (channel instanceof DatagramChannel) { + return ((DatagramChannel) channel).isConnected(); + } else if (channel instanceof SocketChannel) { + return ((SocketChannel) channel).isConnected(); + } else { + throw new IllegalArgumentException("isConnected on unexpected channel type: " + channel); + } + } + + private void processConnectedSelection(SelectionKey key, Connection connection) { + // Whilst connected, we always want READ and not CONNECT events + connection.unsubscribeKey(SelectionKey.OP_CONNECT); + connection.subscribeKey(SelectionKey.OP_READ); + processSelectorRead(key, connection); + processPendingWrite(key, connection); + } + + private void processSelectorRead(SelectionKey selectionKey, Connection connection) { + boolean canRead; + synchronized (selectionKey) { + // There's a race here that requires a lock, as isReadable requires isValid + canRead = selectionKey.isValid() && selectionKey.isReadable(); + } + + if (canRead) reader.read(connection); + } + + private void processPendingWrite(SelectionKey selectionKey, Connection connection) { + // Nothing to write? Skip this entirely, and make sure we're not subscribed + if (!connection.hasDataToSend() || !connection.isDataForSendingReady()) { + connection.unsubscribeKey(SelectionKey.OP_WRITE); + return; + } + + boolean canWrite; + synchronized (selectionKey) { + // There's a race here that requires a lock, as isReadable requires isValid + canWrite = selectionKey.isValid() && selectionKey.isWritable(); + } + + if (canWrite) { + connection.unsubscribeKey(SelectionKey.OP_WRITE); + writer.write(connection); // This will resubscribe to OP_WRITE if it can't complete + } + } +} diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/transport/Packet.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/Packet.kt new file mode 100644 index 0000000..0546c67 --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/Packet.kt @@ -0,0 +1,7 @@ +package com.network.proxy.vpn.transport + +import com.network.proxy.vpn.transport.protocol.IP4Header +import com.network.proxy.vpn.transport.protocol.TransportHeader + +class Packet(var ipHeader: IP4Header, var transportHeader: TransportHeader, var buffer: ByteArray) { +} \ No newline at end of file diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/transport/icmp/ICMPPacket.java b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/icmp/ICMPPacket.java new file mode 100644 index 0000000..fb9306b --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/icmp/ICMPPacket.java @@ -0,0 +1,46 @@ +package com.network.proxy.vpn.transport.icmp; + + +import androidx.annotation.NonNull; + +public class ICMPPacket { + // Two ICMP packets we can handle: simple ping & pong + public static final byte ECHO_REQUEST_TYPE = 8; + public static final byte ECHO_SUCCESS_TYPE = 0; + + // One very common packet we ignore: connection rejection. Unclear why this happens, + // random incoming connections that the phone tries to reply to? Nothing we can do though, + // as we can't forward ICMP onwards, and we can't usefully respond or react. + public static final byte DESTINATION_UNREACHABLE_TYPE = 3; + + public final byte type; + final byte code; // 0 for request, 0 for success, 0 - 15 for error subtypes + + final int checksum; + final int identifier; + final int sequenceNumber; + + final byte[] data; + + ICMPPacket( + int type, + int code, + int checksum, + int identifier, + int sequenceNumber, + byte[] data + ) { + this.type = (byte) type; + this.code = (byte) code; + this.checksum = checksum; + this.identifier = identifier; + this.sequenceNumber = sequenceNumber; + this.data = data; + } + + @NonNull + public String toString() { + return "ICMP packet type " + type + "/" + code + " id:" + identifier + + " seq:" + sequenceNumber + " and " + data.length + " bytes of data"; + } +} diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/transport/icmp/ICMPPacketFactory.java b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/icmp/ICMPPacketFactory.java new file mode 100644 index 0000000..3ae467e --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/icmp/ICMPPacketFactory.java @@ -0,0 +1,80 @@ +package com.network.proxy.vpn.transport.icmp; + + +import androidx.annotation.NonNull; + +import com.network.proxy.vpn.transport.protocol.IP4Header; +import com.network.proxy.vpn.util.PacketUtil; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; + + +public class ICMPPacketFactory { + + public static ICMPPacket parseICMPPacket(@NonNull ByteBuffer stream) { + final byte type = stream.get(); + final byte code = stream.get(); + final int checksum = stream.getShort(); + + final int identifier = stream.getShort(); + final int sequenceNumber = stream.getShort(); + + final byte[] data = new byte[stream.remaining()]; + stream.get(data); + + return new ICMPPacket(type, code, checksum, identifier, sequenceNumber, data); + } + + public static ICMPPacket buildSuccessPacket(ICMPPacket requestPacket) { + return new ICMPPacket( + 0, + 0, + 0, + requestPacket.identifier, + requestPacket.sequenceNumber, + requestPacket.data + ); + } + + public static byte[] packetToBuffer(IP4Header ipHeader, ICMPPacket packet) { + byte[] ipData = ipHeader.toBytes(); + + ByteArrayOutputStream icmpDataBuffer = new ByteArrayOutputStream(); + icmpDataBuffer.write(packet.type); + icmpDataBuffer.write(packet.code); + + icmpDataBuffer.write(asShortBytes(0 /* checksum placeholder */), 0, 2); + + if (packet.type == ICMPPacket.ECHO_REQUEST_TYPE || packet.type == ICMPPacket.ECHO_SUCCESS_TYPE) { + icmpDataBuffer.write(asShortBytes(packet.identifier), 0, 2); + icmpDataBuffer.write(asShortBytes(packet.sequenceNumber), 0, 2); + + byte[] extraData = packet.data; + icmpDataBuffer.write(extraData, 0, extraData.length); + } else { + throw new IllegalArgumentException("Can't serialize unrecognized ICMP packet type"); + } + + byte[] icmpPacketData = icmpDataBuffer.toByteArray(); + byte[] checksum = PacketUtil.INSTANCE.calculateChecksum(icmpPacketData, 0, icmpPacketData.length); + + ByteBuffer resultBuffer = ByteBuffer.allocate(ipData.length + icmpPacketData.length); + resultBuffer.put(ipData); + resultBuffer.put(icmpPacketData); + + // Replace the checksum placeholder + resultBuffer.position(ipData.length + 2); + resultBuffer.put(checksum); + resultBuffer.position(0); + + byte[] result = new byte[resultBuffer.remaining()]; + resultBuffer.get(result); + return result; + } + + private static byte[] asShortBytes(int value) { + return ByteBuffer.allocate(2).putShort((short) value).array(); + } + +} diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/IP4Header.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/IP4Header.kt new file mode 100644 index 0000000..7c3ac00 --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/IP4Header.kt @@ -0,0 +1,142 @@ +package com.network.proxy.vpn.transport.protocol + +import java.nio.ByteBuffer +import java.nio.ByteOrder + +/** + * IPv4报头的数据结构。 + */ +data class IP4Header( + var ipVersion: Byte = 0, //对于IPv4,其值为4(因此命名为IPv4)。 4bit + private var internetHeaderLength: Byte = 0, //头部长度 4bit + private var diffTypeOfService: Byte, //差分服务代码点 =>6位 + private var ecn: Byte = 0, //显式拥塞通知(ECN) + var totalLength: Int = 0, //此IP数据包的总长度 16bit + var identification: Int = 0, //主要用于唯一标识单个IP数据报的片段组。 16bit + private var mayFragment: Boolean, // 1bit 用于指示数据报是否可以分段。 + private var lastFragment: Boolean, // 1bit 用于指示数据报是否是片段中的最后一个。 + var fragmentOffset: Short = 0, //13bit,指定特定片段相对于原始未分段的IP数据报的开始的偏移量。 + private var timeToLive: Byte = 0, //用于防止数据报持续存在。8bit + var protocol: Byte = 0, //定义IP数据报的数据部分中使用的协议。 8bit + var headerChecksum: Int = 0, //用于对头部进行错误检查的16位字段。 16bit + var sourceIP: Int = 0, //发送者的IPv4地址。 32bit + var destinationIP: Int = 0 //接收者的IPv4地址。 32bit +) { + //用于控制或识别片段的3比特字段。 + //bit 0: 保留;必须为零 + //bit 1: Don't Fragment (DF) + //bit 2: More Fragments (MF) + private var flag: Byte = initFlag() + + private fun initFlag(): Byte { + var initFlag = 0 + if (mayFragment) { + initFlag = 0x40 + } + + if (lastFragment) { + initFlag = (initFlag or 0x20) + } + return initFlag.toByte() + } + + fun setMayFragment(mayFragment: Boolean) { + this.mayFragment = mayFragment + flag = if (mayFragment) { + (flag.toInt() or 0x40).toByte() + } else { + (flag.toInt() and 0xBF).toByte() + } + } + + fun getIPHeaderLength(): Int { + return internetHeaderLength * 4 + } + + fun copy(): IP4Header { + return IP4Header( + ipVersion, internetHeaderLength, diffTypeOfService, ecn, totalLength, identification, + mayFragment, lastFragment, fragmentOffset, timeToLive, protocol, headerChecksum, + sourceIP, destinationIP + ) + } + + fun toBytes(): ByteArray { + val buffer = ByteBuffer.allocate(getIPHeaderLength()) + buffer.order(ByteOrder.BIG_ENDIAN) + val versionAndHeaderLength = (ipVersion.toInt() shl 4) + internetHeaderLength + buffer.put(versionAndHeaderLength.toByte()) + + val typeOfService: Byte = (diffTypeOfService.toInt() shl 2 and (ecn + .toInt() and 0xFF)).toByte() + buffer.put(typeOfService) + + buffer.putShort(totalLength.toShort()) + buffer.putShort(identification.toShort()) + + //组合标志和部分片段偏移 + buffer.put((fragmentOffset.toInt() shr 8 and 0x1F or flag.toInt()).toByte()) + buffer.put(fragmentOffset.toByte()) + + buffer.put(timeToLive) + buffer.put(protocol) + buffer.putShort(headerChecksum.toShort()) + buffer.putInt(sourceIP) + buffer.putInt(destinationIP) + return buffer.array() + } +} + +object IPPacketFactory { + private const val IP4_HEADER_SIZE = 20 + private const val IP4_VERSION = 0x04 + + /** + * 从给定的ByteBuffer流创建IPv4标头 + */ + fun createIP4Header(buffer: ByteBuffer): IP4Header? { + if (buffer.remaining() < IP4_HEADER_SIZE) { + throw IllegalArgumentException("IP header byte array must have at least $IP4_HEADER_SIZE bytes") + } + + val versionAndHeaderLength: Byte = buffer.get() + val ipVersion = (versionAndHeaderLength.toInt() shr 4).toByte() + if (ipVersion.toInt() != IP4_VERSION) { +// throw IllegalArgumentException("Invalid IP version $ipVersion") + return null + } + + val internetHeaderLength = (versionAndHeaderLength.toInt() and 0x0F).toByte() + + val typeOfService = buffer.get().toInt() + val diffTypeOfService: Byte = (typeOfService shr 2).toByte(); + val ecn: Byte = (typeOfService and 0x03).toByte() + + val totalLength: Int = buffer.getShort().toInt() + val identification: Int = buffer.getShort().toInt() + + val flagsAndFragmentOffset: Short = buffer.getShort() + val mayFragment = flagsAndFragmentOffset.toInt() and 0x4000 != 0 + val lastFragment = flagsAndFragmentOffset.toInt() and 0x2000 != 0 + val fragmentOffset = (flagsAndFragmentOffset.toInt() and 0x1FFF).toShort() + + val timeToLive: Byte = buffer.get() + val protocol: Byte = buffer.get() + val checksum: Int = buffer.getShort().toInt() + val sourceIp: Int = buffer.getInt() + val desIp: Int = buffer.getInt() + + if (internetHeaderLength > 5) { + // drop the IP option + for (i in 0 until (internetHeaderLength - 5)) { + buffer.getInt() + } + } + + return IP4Header( + ipVersion, internetHeaderLength, diffTypeOfService, ecn, totalLength, identification, + mayFragment, lastFragment, fragmentOffset, timeToLive, protocol, checksum, + sourceIp, desIp + ) + } +} \ No newline at end of file diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/TCPHeader.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/TCPHeader.kt new file mode 100644 index 0000000..7e196c9 --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/TCPHeader.kt @@ -0,0 +1,212 @@ +package com.network.proxy.vpn.transport.protocol + +import java.nio.ByteBuffer +import java.nio.ByteOrder + +/** + * TCP报头的数据结构。 + */ +class TCPHeader( + private var sourcePort: Int = 0, //源端口号 16bit + private var destinationPort: Int = 0, //目的端口号 16bit + var sequenceNumber: Long = 0, //序列号 32bit + var ackNumber: Long = 0, //确认号 32bit + var dataOffset: Int = 0, //数据偏移4bit + var isNS: Boolean = false, //ECN-nonce concealment protection (experimental: see RFC 3540) + var flags: Int = 0, //标志位 9bit + var windowSize: Int = 0, //窗口大小 16bit + var checksum: Int = 0, //校验和 16bit + private var urgentPointer: Int = 0, //紧急指针 16bit + var options: ByteArray? = null //选项 +) : TransportHeader { + + //options + var maxSegmentSize: Short = 0 + private var windowScale: Byte = 0 + private var isSelectiveAckPermitted = false + var timeStampSender = 0 + var timeStampReplyTo = 0 + + companion object { + private const val END_OF_OPTIONS_LIST: Byte = 0 + private const val NO_OPERATION: Byte = 1 + private const val MAX_SEGMENT_SIZE: Byte = 2 + private const val WINDOW_SCALE: Byte = 3 + private const val SELECTIVE_ACK_PERMITTED: Byte = 4 + private const val TIME_STAMP: Byte = 8 + } + + fun isSYN(): Boolean { + return flags and 0x02 != 0 + } + + fun isFIN(): Boolean { + return flags and 0x01 != 0 + } + + fun isRST(): Boolean { + return flags and 0x04 != 0 + } + + fun isPSH(): Boolean { + return flags and 0x08 != 0 + } + + fun isACK(): Boolean { + return flags and 0x10 != 0 + } + + fun isURG(): Boolean { + return flags and 0x20 != 0 + } + + fun isECE(): Boolean { + return flags and 0x40 != 0 + } + + fun isCWR(): Boolean { + return flags and 0x80 != 0 + } + + fun setIsRST(isRST: Boolean) { + flags = if (isRST) { + (flags or 0x04) + } else { + (flags and 0xFB) + } + } + + fun setIsSYN(isSYN: Boolean) { + flags = if (isSYN) { + (flags or 0x02) + } else { + (flags and 0xFD) + } + } + + fun setIsFIN(isFIN: Boolean) { + flags = if (isFIN) { + (flags or 0x01) + } else { + (flags and 0xFE) + } + } + + fun setIsPSH(isPSH: Boolean) { + flags = if (isPSH) { + (flags or 0x08) + } else { + (flags and 0xF7) + } + } + + fun setIsACK(isACK: Boolean) { + flags = if (isACK) { + (flags or 0x10) + } else { + (flags and 0xEF) + } + } + + fun getTCPHeaderLength(): Int { + return dataOffset * 4 + } + + fun toBytes(): ByteArray { + val tcpHeaderLength = getTCPHeaderLength() + val tcpHeader = ByteArray(tcpHeaderLength) + val byteBuffer = ByteBuffer.wrap(tcpHeader) + byteBuffer.order(ByteOrder.BIG_ENDIAN) + + byteBuffer.putShort(sourcePort.toShort()) + byteBuffer.putShort(destinationPort.toShort()) + + byteBuffer.putInt(sequenceNumber.toInt()) + byteBuffer.putInt(ackNumber.toInt()) + + //is ns and data offset + byteBuffer.put(((dataOffset shl 4) and 0xF0 or (if (isNS) 0x1 else 0x0)).toByte()) + byteBuffer.put(flags.toByte()) + byteBuffer.putShort(windowSize.toShort()) + byteBuffer.putShort(checksum.toShort()) + byteBuffer.putShort(urgentPointer.toShort()) +// encodeTcpOptions()?.let { +// byteBuffer.put(it) +// } + + return tcpHeader + } + + fun copy(): TCPHeader { + return TCPHeader( + sourcePort, destinationPort, sequenceNumber, ackNumber, + dataOffset, isNS, flags, windowSize, checksum, urgentPointer, + options + ) + } + + private fun handleTcpOptions() { + if (options == null) { + return + } + + var index = 0 + val packet = ByteBuffer.wrap(options!!) + val optionsSize = options!!.size + + while (index < optionsSize) { + val optionKind = packet.get() + index++ + if (optionKind == END_OF_OPTIONS_LIST || optionKind == NO_OPERATION) { + continue + } + val size = packet.get() + index++ + when (optionKind) { + MAX_SEGMENT_SIZE -> { + maxSegmentSize = packet.getShort() + index += 2 + } + + WINDOW_SCALE -> { + windowScale = packet.get() + index++ + } + + SELECTIVE_ACK_PERMITTED -> isSelectiveAckPermitted = true + TIME_STAMP -> { + timeStampSender = packet.getInt() + timeStampReplyTo = packet.getInt() + index += 8 + } + + else -> { + skipRemainingOptions(packet, size.toInt()) + index = index + size - 2 + } + } + } + } + + private fun skipRemainingOptions(packet: ByteBuffer, size: Int) { + for (i in 2 until size) { + packet.get() + } + } + + override fun getSourcePort(): Int { + return sourcePort + } + + override fun getDestinationPort(): Int { + return destinationPort + } + + fun setSourcePort(sourcePort: Int) { + this.sourcePort = sourcePort + } + + fun setDestinationPort(destinationPort: Int) { + this.destinationPort = destinationPort + } +} \ No newline at end of file diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/TCPPacketFactory.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/TCPPacketFactory.kt new file mode 100644 index 0000000..d981d7c --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/TCPPacketFactory.kt @@ -0,0 +1,285 @@ +package com.network.proxy.vpn.transport.protocol + +import com.network.proxy.vpn.transport.Packet +import com.network.proxy.vpn.util.PacketUtil +import java.nio.ByteBuffer +import java.util.concurrent.ThreadLocalRandom + +object TCPPacketFactory { + + private const val TCP_HEADER_LENGTH = 20 + + /** + * 从tcp报文创建tcpHeader + */ + @JvmStatic + fun createTCPHeader(byteBuffer: ByteBuffer): TCPHeader { + if (byteBuffer.remaining() < TCP_HEADER_LENGTH) { + throw IllegalArgumentException("Invalid TCP Header Length") + } + + val sourcePort: Int = byteBuffer.getShort().toInt() and 0xFFFF + val destinationPort: Int = byteBuffer.getShort().toInt() and 0xFFFF + val sequenceNumber: Long = byteBuffer.getInt().toLong() + val ackNumber: Long = byteBuffer.getInt().toLong() + + val dataOffsetAndReserved = byteBuffer.get() + val dataOffset = (dataOffsetAndReserved.toInt() and 0xF0) shr 4 + val isNs: Boolean = dataOffsetAndReserved.toInt() and 0x1 > 0x0 + + val flags = byteBuffer.get().toInt() + + val window = byteBuffer.short.toInt() + val checksum = byteBuffer.short.toInt() + val urgentPointer = byteBuffer.short.toInt() + + var optionsAndPadding: ByteArray? = null + val optionsSize = dataOffset - 5 + if (optionsSize > 0) { + optionsAndPadding = ByteArray(optionsSize * 4) + byteBuffer.get(optionsAndPadding, 0, optionsSize * 4) + } + return TCPHeader( + sourcePort, destinationPort, sequenceNumber, ackNumber, + dataOffset, isNs, flags, window, checksum, urgentPointer, optionsAndPadding + ) + } + + /** + * 创建带有RST标志的数据包,以便在需要重置时发送到客户端。 + */ + fun createRstData(ipHeader: IP4Header, tcpHeader: TCPHeader, dataLength: Int): ByteArray { + val ip = ipHeader.copy() + val tcp = tcpHeader.copy() + + var ackNumber: Long = 0 + var seqNumber: Long = 0 + + if (tcp.ackNumber > 0) { + seqNumber = tcp.ackNumber + } else { + ackNumber = tcp.sequenceNumber + dataLength + } + + tcp.ackNumber = ackNumber + tcp.sequenceNumber = seqNumber + + //将IP从源翻转到目标 + flipIp(ip, tcp) + + ip.identification = 0 + + tcp.flags = 0 + tcp.isNS = false + tcp.setIsRST(true) + + tcp.options = null + tcp.windowSize = 0 + + //重新计算IP长度 + val totalLength = ip.getIPHeaderLength() + tcp.getTCPHeaderLength() + + ip.totalLength = totalLength + + return createPacketData(ip, tcp, null) + } + + /** + * 创建数据包数据以发送回客户端 + */ + @JvmStatic + fun createResponsePacketData( + ipHeader: IP4Header, tcpHeader: TCPHeader, packetData: ByteArray?, isPsh: Boolean, + ackNumber: Long, seqNumber: Long, timeSender: Int, timeReplyTo: Int + ): ByteArray { + val ip = ipHeader.copy() + val tcp = tcpHeader.copy() + + flipIp(ip, tcp) + tcp.ackNumber = ackNumber + tcp.sequenceNumber = seqNumber + ip.identification = PacketUtil.getPacketId() + + //总是发送ACK + + //ACK is always sent + tcp.setIsACK(true) + tcp.setIsSYN(false) + tcp.setIsPSH(isPsh) + tcp.setIsFIN(false) + tcp.timeStampSender = timeSender + tcp.timeStampReplyTo = timeReplyTo + + var totalLength = ip.getIPHeaderLength() + tcp.getTCPHeaderLength() + if (packetData != null) { + totalLength += packetData.size + } + ip.totalLength = totalLength + + return createPacketData(ip, tcp, packetData) + } + + + /** + * 向客户端确认服务器已收到请求。 + */ + @JvmStatic + fun createResponseAckData( + ipHeader: IP4Header, tcpHeader: TCPHeader, ackToClient: Long + ): ByteArray { + val ip = ipHeader.copy() + val tcp = tcpHeader.copy() + + flipIp(ip, tcp) + val seqNumber = tcp.ackNumber + tcp.ackNumber = ackToClient + tcp.sequenceNumber = seqNumber + + ip.identification = PacketUtil.getPacketId() + + //ACK + tcp.setIsACK(true) + tcp.setIsSYN(false) + tcp.setIsPSH(false) + tcp.setIsFIN(false) + + ip.totalLength = ip.getIPHeaderLength() + tcp.getTCPHeaderLength() + return createPacketData(ip, tcp, null) + } + + //将IP从源翻转到目标 + private fun flipIp(ip: IP4Header, tcp: TCPHeader) { + val sourceIp = ip.destinationIP + val destIp = ip.sourceIP + val sourcePort = tcp.getDestinationPort() + val destPort = tcp.getSourcePort() + + ip.destinationIP = destIp + ip.sourceIP = sourceIp + tcp.setDestinationPort(destPort) + tcp.setSourcePort(sourcePort) + } + + /** + * 通过写回客户端流创建SYN-ACK数据包数据 + */ + fun createSynAckPacketData(ipHeader: IP4Header, tcpHeader: TCPHeader): Packet { + val ip = ipHeader.copy() + val tcp = tcpHeader.copy() + + flipIp(ip, tcp) + + //ack = received sequence + 1 + val ackNumber = tcpHeader.sequenceNumber + 1 + tcp.ackNumber = ackNumber + + //服务器生成的初始序列号 + val seqNumber = ThreadLocalRandom.current().nextLong(0, 100000) + tcp.sequenceNumber = seqNumber + + //SYN-ACK + tcp.setIsACK(true) + tcp.setIsSYN(true) + + tcp.timeStampReplyTo = tcp.timeStampSender + tcp.timeStampSender = PacketUtil.currentTime + + return Packet(ip, tcp, createPacketData(ip, tcp, null)) + } + + /** + * 创建发送到客户端的FIN-ACK + */ + fun createFinAckData( + ipHeader: IP4Header, tcpHeader: TCPHeader, ackToClient: Long, + seqToClient: Long, isFin: Boolean, isAck: Boolean + ): ByteArray { + val ip = ipHeader.copy() + val tcp = tcpHeader.copy() + + flipIp(ip, tcp) + + tcp.ackNumber = ackToClient + tcp.sequenceNumber = seqToClient + ip.identification = PacketUtil.getPacketId() + + //ACK + tcp.setIsACK(isAck) + tcp.setIsSYN(false) + tcp.setIsPSH(false) + tcp.setIsFIN(isFin) + + ip.totalLength = ip.getIPHeaderLength() + tcp.getTCPHeaderLength() + return createPacketData(ip, tcp, null) + } + + fun createFinData( + ip: IP4Header, tcp: TCPHeader, ackNumber: Long, seqNumber: Long, + timeSender: Int, timeReplyTo: Int + ): ByteArray { + //将IP从源翻转到目标 + flipIp(ip, tcp) + + tcp.ackNumber = ackNumber + tcp.sequenceNumber = seqNumber + + ip.identification = PacketUtil.getPacketId() + + tcp.timeStampReplyTo = timeReplyTo + tcp.timeStampSender = timeSender + + tcp.flags = 0 + tcp.isNS = false + tcp.setIsACK(true) + tcp.setIsFIN(true) + + tcp.options = null + //窗口大小应为零 + tcp.windowSize = 0 + + ip.totalLength = ip.getIPHeaderLength() + tcp.getTCPHeaderLength() + return createPacketData(ip, tcp, null) + } + + /** + * 从tcpHeader创建tcp报文 + */ + private fun createPacketData(ipHeader: IP4Header, tcpHeader: TCPHeader, data: ByteArray?): + ByteArray { + val dataLength = data?.size ?: 0 + + val buffer = + ByteBuffer.allocate(ipHeader.getIPHeaderLength() + tcpHeader.getTCPHeaderLength() + dataLength) + val ipBuffer = ipHeader.toBytes() + val tcpBuffer = tcpHeader.toBytes() + + buffer.put(ipBuffer) + buffer.put(tcpBuffer) + + data?.let { buffer.put(it) } + + val zero = byteArrayOf(0, 0) + //计算前先将校验和清零 + buffer.position(10) + buffer.put(zero) + + val ipChecksum = PacketUtil.calculateChecksum(buffer.array(), 0, ipBuffer.size) + buffer.position(10) + buffer.put(ipChecksum) + + val tcpStart = ipBuffer.size + buffer.position(tcpStart + 16) + buffer.put(zero) + + val tcpChecksum = PacketUtil.calculateTCPHeaderChecksum( + buffer.array(), tcpStart, tcpBuffer.size + dataLength, + ipHeader.destinationIP, ipHeader.sourceIP + ) + + //将新的校验和写回阵列 + buffer.position(tcpStart + 16) + buffer.put(tcpChecksum) + return buffer.array() + } + +} diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/TransportHeader.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/TransportHeader.kt new file mode 100644 index 0000000..34829db --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/TransportHeader.kt @@ -0,0 +1,6 @@ +package com.network.proxy.vpn.transport.protocol + +interface TransportHeader { + fun getSourcePort(): Int + fun getDestinationPort(): Int +} \ No newline at end of file diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/UDPHeader.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/UDPHeader.kt new file mode 100644 index 0000000..6ddc5dc --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/transport/protocol/UDPHeader.kt @@ -0,0 +1,89 @@ +package com.network.proxy.vpn.transport.protocol + + +import com.network.proxy.vpn.util.PacketUtil +import java.nio.ByteBuffer + + +/** + * UDP报头的数据结构。 + */ +data class UDPHeader( + var sourcePort: Int = 0, //源端口号 16bit + var destinationPort: Int = 0, //目的端口号 16bit + var length: Int = 0, //UDP数据报长度 16bit + var checksum: Int = 0 //校验和 16bit +) + + +object UDPPacketFactory { + @JvmStatic + fun createUDPHeader(stream: ByteBuffer): UDPHeader { + require(stream.remaining() >= 8) { "Minimum UDP header is 8 bytes." } + val srcPort = stream.getShort().toInt() and 0xffff + val destPort = stream.getShort().toInt() and 0xffff + val length = stream.getShort().toInt() and 0xffff + val checksum = stream.getShort().toInt() + return UDPHeader(srcPort, destPort, length, checksum) + } + + /** + * 创建用于响应vpn客户端的数据包 + */ + @JvmStatic + fun createResponsePacket(ip: IP4Header, udp: UDPHeader, packetData: ByteArray?): ByteArray { + val buffer: ByteArray + var udpLen = 8 + if (packetData != null) { + udpLen += packetData.size + } + val srcPort = udp.destinationPort + val destPort = udp.sourcePort + val checksum: Short = 0 + val ipHeader = ip.copy() + val srcIp = ip.destinationIP + val destIp = ip.sourceIP + ipHeader.setMayFragment(false) + ipHeader.sourceIP = srcIp + ipHeader.destinationIP = destIp + ipHeader.identification = PacketUtil.getPacketId() + + //ip的长度是整个数据包的长度 => IP header length + UDP header length (8) + UDP body length + val totalLength = ipHeader.getIPHeaderLength() + udpLen + ipHeader.totalLength = totalLength + buffer = ByteArray(totalLength) + val ipData = ipHeader.toBytes() + + // clear IP checksum + ipData[11] = 0 + ipData[10] = 0 + + //calculate checksum for IP header + val ipChecksum = PacketUtil.calculateChecksum(ipData, 0, ipData.size) + //write result of checksum back to buffer + System.arraycopy(ipChecksum, 0, ipData, 10, 2) + System.arraycopy(ipData, 0, buffer, 0, ipData.size) + + //copy UDP header to buffer + var start = ipData.size + val intContainer = ByteArray(4) + PacketUtil.writeIntToBytes(srcPort, intContainer, 0) + //extract the last two bytes of int value + System.arraycopy(intContainer, 2, buffer, start, 2) + start += 2 + PacketUtil.writeIntToBytes(destPort, intContainer, 0) + System.arraycopy(intContainer, 2, buffer, start, 2) + start += 2 + PacketUtil.writeIntToBytes(udpLen, intContainer, 0) + System.arraycopy(intContainer, 2, buffer, start, 2) + start += 2 + PacketUtil.writeIntToBytes(checksum.toInt(), intContainer, 0) + System.arraycopy(intContainer, 2, buffer, start, 2) + start += 2 + + //now copy udp data + if (packetData != null) System.arraycopy(packetData, 0, buffer, start, packetData.size) + return buffer + } +} + diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/util/PacketUtil.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/util/PacketUtil.kt new file mode 100644 index 0000000..bc7f831 --- /dev/null +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/util/PacketUtil.kt @@ -0,0 +1,265 @@ +package com.network.proxy.vpn.util + +import android.util.Log +import com.network.proxy.vpn.formatTag +import com.network.proxy.vpn.transport.protocol.IP4Header +import com.network.proxy.vpn.transport.protocol.TCPHeader +import java.nio.ByteBuffer +import java.nio.ByteOrder + +/** + * Helper class to perform various useful task + * + * @author Borey Sao + * Date: May 8, 2014 + */ +object PacketUtil { + @get:Synchronized + private var packetId = 0 + fun getPacketId() = packetId++ + + val currentTime: Int + get() = (System.currentTimeMillis() / 1000).toInt() + + /** + * convert int to byte array + * [...](https://docs.oracle.com/javase/tutorial/java/nutsandbolts/datatypes.html) + * + * @param value int value 32 bits + * @param buffer array of byte to write to + * @param offset position to write to + */ + fun writeIntToBytes(value: Int, buffer: ByteArray, offset: Int) { + if (buffer.size - offset < 4) { + return + } + buffer[offset] = (value ushr 24 and 0x000000FF).toByte() + buffer[offset + 1] = (value shr 16 and 0x000000FF).toByte() + buffer[offset + 2] = (value shr 8 and 0x000000FF).toByte() + buffer[offset + 3] = (value and 0x000000FF).toByte() + } + + /** + * convert array of max 4 bytes to int + * + * @param buffer byte array + * @param start Starting point to be read in byte array + * @param length Length to be read + * @return value of int + */ + fun getNetworkInt(buffer: ByteArray, start: Int, length: Int): Int { + var value = 0 + var end = start + Math.min(length, 4) + if (end > buffer.size) end = buffer.size + for (i in start until end) { + value = value or (buffer[i].toInt() and 0xFF) + if (i < end - 1) value = value shl 8 + } + return value + } + + /** + * validate TCP header checksum + * + * @param source Source Port + * @param destination Destination Port + * @param data Payload + * @param tcpLength TCP Header length + * @return boolean + */ + fun isValidTCPChecksum( + source: Int, destination: Int, + data: ByteArray, tcpLength: Short, tcpOffset: Int + ): Boolean { + var buffersize = tcpLength + 12 + var isodd = false + if (buffersize % 2 != 0) { + buffersize++ + isodd = true + } + val buffer = ByteBuffer.allocate(buffersize) + buffer.putInt(source) + buffer.putInt(destination) + buffer.put(0.toByte()) //reserved => 0 + buffer.put(6.toByte()) //TCP protocol => 6 + buffer.putShort(tcpLength) + buffer.put(data, tcpOffset, tcpLength.toInt()) + if (isodd) { + buffer.put(0.toByte()) + } + return isValidIPChecksum(buffer.array(), buffersize) + } + + /** + * validate IP Header checksum + * + * @param data byte stream + * @return boolean + */ + private fun isValidIPChecksum(data: ByteArray, length: Int): Boolean { + var start = 0 + var sum = 0 + while (start < length) { + sum += getNetworkInt(data, start, 2) + start = start + 2 + } + + //carry over one's complement + while (sum shr 16 > 0) sum = (sum and 0xffff) + (sum shr 16) + + //flip the bit to get one' complement + sum = sum.inv() + val buffer = ByteBuffer.allocate(4) + buffer.putInt(sum) + return buffer.getShort(2).toInt() == 0 + } + + fun calculateChecksum(data: ByteArray, offset: Int, length: Int): ByteArray { + var start = offset + var sum = 0 + while (start < length) { + sum += getNetworkInt(data, start, 2) + start = start + 2 + } + //carry over one's complement + while (sum shr 16 > 0) { + sum = (sum and 0xffff) + (sum shr 16) + } + //flip the bit to get one' complement + sum = sum.inv() + + //extract the last two byte of int + val checksum = ByteArray(2) + checksum[0] = (sum shr 8).toByte() + checksum[1] = sum.toByte() + return checksum + } + + fun calculateTCPHeaderChecksum( + data: ByteArray, + offset: Int, + tcplength: Int, + destip: Int, + sourceip: Int + ): ByteArray { + var buffersize = tcplength + 12 + var odd = false + if (buffersize % 2 != 0) { + buffersize++ + odd = true + } + val buffer = ByteBuffer.allocate(buffersize) + buffer.order(ByteOrder.BIG_ENDIAN) + + //create virtual header + buffer.putInt(sourceip) + buffer.putInt(destip) + buffer.put(0.toByte()) //reserved => 0 + buffer.put(6.toByte()) //tcp protocol => 6 + buffer.putShort(tcplength.toShort()) + + //add actual header + data + buffer.put(data, offset, tcplength) + + //padding last byte to zero + if (odd) { + buffer.put(0.toByte()) + } + val tcparray = buffer.array() + return calculateChecksum(tcparray, 0, buffersize) + } + + fun intToIPAddress(addressInt: Int): String { + return (addressInt ushr 24 and 0x000000FF).toString() + "." + + (addressInt ushr 16 and 0x000000FF) + "." + + (addressInt ushr 8 and 0x000000FF) + "." + + (addressInt and 0x000000FF) + } + + fun getOutput( + ipHeader: IP4Header, tcpheader: TCPHeader, + packetData: ByteArray + ): String { + val tcpLength = (packetData.size - + ipHeader.getIPHeaderLength()).toShort() + val isValidChecksum = isValidTCPChecksum( + ipHeader.sourceIP, ipHeader.destinationIP, + packetData, tcpLength, ipHeader.getIPHeaderLength() + ) + val isValidIPChecksum = isValidIPChecksum( + packetData, + ipHeader.getIPHeaderLength() + ) + val packetBodyLength = (packetData.size - ipHeader.getIPHeaderLength() + - tcpheader.getTCPHeaderLength()) + val str = StringBuilder("\r\nIP Version: ") + .append(ipHeader.ipVersion.toInt()) + .append("\r\nProtocol: ").append(ipHeader.protocol.toInt()) + .append("\r\nID# ").append(ipHeader.identification) + .append("\r\nTotal Length: ").append(ipHeader.totalLength) + .append("\r\nData Length: ").append(packetBodyLength) + .append("\r\nDest: ").append(intToIPAddress(ipHeader.destinationIP)) + .append(":").append(tcpheader.getDestinationPort()) + .append("\r\nSrc: ").append(intToIPAddress(ipHeader.sourceIP)) + .append(":").append(tcpheader.getSourcePort()) + .append("\r\nACK: ").append(tcpheader.ackNumber) + .append("\r\nSeq: ").append(tcpheader.sequenceNumber) + .append("\r\nIP Header length: ").append(ipHeader.getIPHeaderLength()) + .append("\r\nTCP Header length: ").append(tcpheader.getTCPHeaderLength()) + .append("\r\nACK: ").append(tcpheader.isACK()) + .append("\r\nSYN: ").append(tcpheader.isSYN()) + .append("\r\nCWR: ").append(tcpheader.isCWR()) + .append("\r\nECE: ").append(tcpheader.isECE()) + .append("\r\nFIN: ").append(tcpheader.isFIN()) + .append("\r\nNS: ").append(tcpheader.isNS) + .append("\r\nPSH: ").append(tcpheader.isPSH()) + .append("\r\nRST: ").append(tcpheader.isRST()) + .append("\r\nURG: ").append(tcpheader.isURG()) + .append("\r\nIP checksum: ").append(ipHeader.headerChecksum) + .append("\r\nIs Valid IP Checksum: ").append(isValidIPChecksum) + .append("\r\nTCP Checksum: ").append(tcpheader.checksum) + .append("\r\nIs Valid TCP checksum: ").append(isValidChecksum) + .append("\r\nFragment Offset: ").append(ipHeader.fragmentOffset.toInt()) + .append("\r\nWindow: ").append(tcpheader.windowSize) + .append("\r\nData Offset: ").append(tcpheader.dataOffset) + return str.toString() + } + + /** + * detect packet corruption flag in tcp options sent from client ACK + * + * @param tcpHeader TCPHeader + * @return boolean + */ + fun isPacketCorrupted(tcpHeader: TCPHeader): Boolean { + val options = tcpHeader.options + if (options != null) { + var i = 0 + while (i < options.size) { + val kind = options[i] + if (kind.toInt() == 0 || kind.toInt() == 1) { + } else if (kind.toInt() == 2) { + i += 3 + } else if (kind.toInt() == 3 || kind.toInt() == 14) { + i += 2 + } else if (kind.toInt() == 4) { + i++ + } else if (kind.toInt() == 5 || kind.toInt() == 15) { + i = i + options[++i] - 2 + } else if (kind.toInt() == 8) { + i += 9 + } else if (kind.toInt() == 23) { + return true + } else { + Log.e( + formatTag(PacketUtil::class.java.name), + "unknown option: $kind" + ) + } + i++ + } + } + return false + } +} + diff --git a/android/app/src/main/kotlin/com/network/proxy/vpn/util/TLS.kt b/android/app/src/main/kotlin/com/network/proxy/vpn/util/TLS.kt index ec70303..9480aa7 100644 --- a/android/app/src/main/kotlin/com/network/proxy/vpn/util/TLS.kt +++ b/android/app/src/main/kotlin/com/network/proxy/vpn/util/TLS.kt @@ -4,84 +4,82 @@ import java.nio.ByteBuffer import kotlin.math.min -class TLS { +object TLS { - companion object { - /** - * 判断是否是TLS Client Hello - */ - fun isTLSClientHello(packetData: ByteBuffer): Boolean { - if (packetData.remaining() < 43) return false - val position = packetData.position() - val data = packetData.array() - if (data[position].toInt() != 0x16 /* handshake */) return false - if (data[1 + position].toInt() != 0x03) return false - return if (data[5 + position].toInt() != 0x01) false else data[9 + position].toInt() == 0x03 && data[10 + position] >= 0x00 && data[1 + position] <= 0x03 - } + /** + * 判断是否是TLS Client Hello + */ + fun isTLSClientHello(packetData: ByteBuffer): Boolean { + if (packetData.remaining() < 43) return false + val position = packetData.position() + val data = packetData.array() + if (data[position].toInt() != 0x16 /* handshake */) return false + if (data[1 + position].toInt() != 0x03) return false + return if (data[5 + position].toInt() != 0x01) false else data[9 + position].toInt() == 0x03 && data[10 + position] >= 0x00 && data[1 + position] <= 0x03 + } - /** - * 从TLS Client Hello 解析域名 - */ - fun getDomain(buffer: ByteBuffer): String? { - var offset = buffer.position() - val limit = buffer.limit() - //TLS Client Hello - if (buffer[offset].toInt() != 0x16) return null - //Skip 43 byte header - offset += 43 - if (limit < (offset + 1)) return null + /** + * 从TLS Client Hello 解析域名 + */ + fun getDomain(buffer: ByteBuffer): String? { + var offset = buffer.position() + val limit = buffer.limit() + //TLS Client Hello + if (buffer[offset].toInt() != 0x16) return null + //Skip 43 byte header + offset += 43 + if (limit < (offset + 1)) return null - //read session id - val sessionIDLength = buffer[offset++] - offset += sessionIDLength + //read session id + val sessionIDLength = buffer[offset++] + offset += sessionIDLength - //read cipher suites - if (offset + 2 > limit) return null - val cipherSuitesLength = buffer.getShort(offset) - offset += 2 - offset += cipherSuitesLength + //read cipher suites + if (offset + 2 > limit) return null + val cipherSuitesLength = buffer.getShort(offset) + offset += 2 + offset += cipherSuitesLength - //read Compression method. - if (offset + 1 > limit) return null - val compressionMethodLength = buffer[offset++].toInt() and 0xFF - offset += compressionMethodLength - if (offset > limit) return null + //read Compression method. + if (offset + 1 > limit) return null + val compressionMethodLength = buffer[offset++].toInt() and 0xFF + offset += compressionMethodLength + if (offset > limit) return null - //read Extensions - if (offset + 2 > limit) return null + //read Extensions + if (offset + 2 > limit) return null - val extensionsLength = buffer.getShort(offset) - offset += 2 - if (offset + extensionsLength > limit) return null + val extensionsLength = buffer.getShort(offset) + offset += 2 + if (offset + extensionsLength > limit) return null - var end: Int = offset + extensionsLength - end = min(end, limit) - while (offset + 4 <= end) { - val extensionType = buffer.getShort(offset) - val extensionLength = buffer.getShort(offset + 2) - offset += 4 - //server_name - if (extensionType.toInt() == 0) { - if (offset + 5 > limit) return null - val serverNameListLength = buffer.getShort(offset) - offset += 2 - if (offset > limit) return null - if (offset + serverNameListLength > limit) return null + var end: Int = offset + extensionsLength + end = min(end, limit) + while (offset + 4 <= end) { + val extensionType = buffer.getShort(offset) + val extensionLength = buffer.getShort(offset + 2) + offset += 4 + //server_name + if (extensionType.toInt() == 0) { + if (offset + 5 > limit) return null + val serverNameListLength = buffer.getShort(offset) + offset += 2 + if (offset > limit) return null + if (offset + serverNameListLength > limit) return null - val serverNameType = buffer[offset++] - val serverNameLength = buffer.getShort(offset) - offset += 2 - if (offset > limit || serverNameType.toInt() != 0) return null - if (offset + serverNameLength > limit) return null - val serverNameBytes = ByteArray(serverNameLength.toInt()) - buffer.get(serverNameBytes) - return String(serverNameBytes) - } else { - offset += extensionLength - } + val serverNameType = buffer[offset++] + val serverNameLength = buffer.getShort(offset) + offset += 2 + if (offset > limit || serverNameType.toInt() != 0) return null + if (offset + serverNameLength > limit) return null + val serverNameBytes = ByteArray(serverNameLength.toInt()) + buffer.get(serverNameBytes) + return String(serverNameBytes) + } else { + offset += extensionLength } - return null } + return null } } \ No newline at end of file diff --git a/lib/network/network.dart b/lib/network/network.dart index 9db7931..31a715c 100644 --- a/lib/network/network.dart +++ b/lib/network/network.dart @@ -133,6 +133,7 @@ class Server extends Network { void ssl(ChannelContext channelContext, Channel channel, Uint8List data) async { var hostAndPort = channelContext.host; try { + hostAndPort?.scheme = HostAndPort.httpsScheme; if (hostAndPort == null && TLS.getDomain(data) != null) { hostAndPort = HostAndPort.host(TLS.getDomain(data)!, 443); } @@ -140,7 +141,7 @@ class Server extends Network { Channel? remoteChannel = channelContext.serverChannel; - if (HostFilter.filter(hostAndPort?.host)) { + if (HostFilter.filter(hostAndPort?.host) || !configuration.enableSsl) { remoteChannel = remoteChannel ?? await channelContext.connectServerChannel(hostAndPort!, RelayHandler(channel)); relay(channel, remoteChannel); channel.pipeline.channelRead(channelContext, channel, data);