安卓vpn流量拦截

This commit is contained in:
wanghongenpin
2023-12-30 03:14:30 +08:00
parent ba4072691e
commit 22a31c286d
23 changed files with 2834 additions and 68 deletions

View File

@@ -12,6 +12,7 @@ import android.os.Build
import android.os.ParcelFileDescriptor
import android.util.Log
import androidx.core.app.NotificationCompat
import com.network.proxy.vpn.ProxyVpnThread
import com.network.proxy.vpn.socket.ProtectSocket
import com.network.proxy.vpn.socket.ProtectSocketHolder
@@ -21,6 +22,7 @@ import com.network.proxy.vpn.socket.ProtectSocketHolder
*/
class ProxyVpnService : VpnService(), ProtectSocket {
private var vpnInterface: ParcelFileDescriptor? = null
private var vpnThread: ProxyVpnThread? = null
companion object {
const val MAX_PACKET_LEN = 1500
@@ -86,6 +88,7 @@ class ProxyVpnService : VpnService(), ProtectSocket {
}
private fun disconnect() {
vpnThread?.run { stopThread() }
vpnInterface?.close()
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
stopForeground(STOP_FOREGROUND_REMOVE)
@@ -111,6 +114,12 @@ class ProxyVpnService : VpnService(), ProtectSocket {
ProtectSocketHolder.setProtectSocket(this)
showServiceNotification()
vpnThread = ProxyVpnThread(
vpnInterface!!,
proxyHost,
proxyPort
)
vpnThread!!.start()
isRunning = true
}

View File

@@ -0,0 +1,179 @@
package com.network.proxy.vpn
import android.util.Log
import com.network.proxy.vpn.socket.CloseableConnection
import com.network.proxy.vpn.transport.protocol.IP4Header
import com.network.proxy.vpn.transport.protocol.TCPHeader
import com.network.proxy.vpn.transport.protocol.UDPHeader
import com.network.proxy.vpn.util.PacketUtil
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.channels.SelectionKey
import java.nio.channels.spi.AbstractSelectableChannel
import kotlin.concurrent.Volatile
class Connection(
val protocol: Protocol,
val sourceIp: Int, val sourcePort: Int,
val destinationIp: Int, val destinationPort: Int,
private val connectionCloser: CloseableConnection
) {
var channel: AbstractSelectableChannel? = null
var selectionKey: SelectionKey? = null
//接收用于存储来自远程主机的数据的缓冲器
private val receivingStream: ByteArrayOutputStream = ByteArrayOutputStream()
//发送缓冲区用于存储要从vpn客户端发送到目标主机的数据
private val sendingStream: ByteArrayOutputStream = ByteArrayOutputStream()
var hasReceivedLastSegment = false
/**
* 是否初始化链接 针对代理判断协议延迟初始化
*/
var isInitConnect = false
//指示三向握手是否已完成
var isConnected = false
//从客户端接收的最后一个数据包
var lastIpHeader: IP4Header? = null
var lastTcpHeader: TCPHeader? = null
var lastUdpHeader: UDPHeader? = null
var timestampSender = 0
var timestampReplyTo = 0
//从客户端接收的序列
var recSequence: Long = 0
//在tcp选项内的SYN期间由客户端发送
var maxSegmentSize = 0
//跟踪我们发送给客户端的ack并等待客户端返回ack
var sendUnAck: Long = 0
//发送到客户端的下一个ack
var sendNext: Long = 0
//true when connection is about to be close
var isClosingConnection = false
//指示客户端的数据已准备好发送到目标
@Volatile
var isDataForSendingReady = false
//closing session and aborting connection, will be done by background task
@Volatile
var isAbortingConnection = false
//indicate that vpn client has sent FIN flag and it has been acked
var isAckedToFin = false
companion object {
fun getConnectionKey(
protocol: Protocol, destIp: Int, destPort: Int, sourceIp: Int, sourcePort: Int
): String {
return protocol.name + "|" + PacketUtil.intToIPAddress(sourceIp) + ":" + sourcePort +
"->" + PacketUtil.intToIPAddress(destIp) + ":" + destPort
}
}
// fun getConnectionKey(): String {
// return getConnectionKey(protocol, destinationIp, destinationIp, sourceIp, sourcePort)
// }
fun closeConnection() {
connectionCloser.closeConnection(this)
}
/**
* 设置要发送到目标服务器的数据
*/
@Synchronized
fun setSendingData(data: ByteBuffer): Int {
val remaining = data.remaining()
sendingStream.write(data.array(), data.position(), data.remaining())
return remaining
}
@Synchronized
fun addReceivedData(data: ByteArray?) {
try {
receivingStream.write(data)
} catch (e: IOException) {
Log.e(TAG, e.toString())
}
}
/**
* 获取缓冲区中接收到的所有数据并清空它。
*/
@Synchronized
fun getReceivedData(maxSize: Int): ByteArray? {
var data = receivingStream.toByteArray()
receivingStream.reset()
if (data.size > maxSize) {
val small = ByteArray(maxSize)
System.arraycopy(data, 0, small, 0, maxSize)
val len = data.size - maxSize
receivingStream.write(data, maxSize, len)
data = small
}
return data
}
/**
* buffer has more data for vpn client
*/
fun hasReceivedData(): Boolean {
return receivingStream.size() > 0
}
fun hasDataToSend(): Boolean {
return sendingStream.size() > 0
}
/**
* 出列数据以发送到服务器
*/
@Synchronized
fun getSendingData(): ByteArray? {
val data = sendingStream.toByteArray()
sendingStream.reset()
return data
}
fun cancelKey() {
selectionKey?.let {
synchronized(it) {
if (!it.isValid) return
it.cancel()
}
}
}
fun subscribeKey(op: Int) {
selectionKey?.let {
synchronized(it) {
if (!it.isValid) return
it.interestOps(it.interestOps() or op)
}
}
}
fun unsubscribeKey(op: Int) {
selectionKey?.let {
synchronized(it) {
if (!it.isValid) return
it.interestOps(it.interestOps() and op.inv())
}
}
}
}

View File

@@ -0,0 +1,510 @@
package com.network.proxy.vpn
import android.os.Build
import android.util.Log
import com.network.proxy.vpn.Connection.Companion.getConnectionKey
import com.network.proxy.vpn.socket.ClientPacketWriter
import com.network.proxy.vpn.socket.SocketNIODataService
import com.network.proxy.vpn.transport.icmp.ICMPPacket
import com.network.proxy.vpn.transport.icmp.ICMPPacketFactory
import com.network.proxy.vpn.transport.protocol.IP4Header
import com.network.proxy.vpn.transport.protocol.IPPacketFactory
import com.network.proxy.vpn.transport.protocol.TCPHeader
import com.network.proxy.vpn.transport.protocol.TCPPacketFactory
import com.network.proxy.vpn.transport.protocol.UDPPacketFactory
import com.network.proxy.vpn.util.PacketUtil.getOutput
import com.network.proxy.vpn.util.PacketUtil.intToIPAddress
import com.network.proxy.vpn.util.PacketUtil.isPacketCorrupted
import com.network.proxy.vpn.util.TLS.getDomain
import com.network.proxy.vpn.util.TLS.isTLSClientHello
import java.io.IOException
import java.net.InetAddress
import java.net.InetSocketAddress
import java.net.SocketAddress
import java.nio.ByteBuffer
import java.nio.channels.SelectionKey
import java.nio.channels.SocketChannel
import java.util.concurrent.ExecutorService
import java.util.concurrent.SynchronousQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
class ConnectionHandler(
private val manager: ConnectionManager,
private val nioService: SocketNIODataService,
private val writer: ClientPacketWriter
) {
private val pingThreadPool: ExecutorService = ThreadPoolExecutor(
1, 20, // 1 - 20 parallel pings max
60L, TimeUnit.SECONDS,
SynchronousQueue(),
ThreadPoolExecutor.DiscardPolicy() // Replace running pings if there's too many
)
/**
* Handle unknown raw IP packet data
*
* @param stream ByteBuffer to be read
*/
@Throws(IOException::class)
fun handlePacket(stream: ByteBuffer) {
val rawPacket = ByteArray(stream.limit())
stream[rawPacket, 0, stream.limit()]
stream.rewind()
val ipHeader = IPPacketFactory.createIP4Header(stream)
if (ipHeader == null) {
stream.rewind()
Log.w(TAG, "Malformed IP packet ")
return
}
if (ipHeader.protocol.toInt() == 6) {
handleTCPPacket(stream, ipHeader)
} else if (ipHeader.protocol.toInt() == 17) {
handleUDPPacket(stream, ipHeader)
} else if (ipHeader.protocol.toInt() == 1) {
handleICMPPacket(stream, ipHeader)
} else {
Log.w(TAG, "Unsupported IP protocol: " + ipHeader.protocol)
}
}
@Throws(IOException::class)
private fun handleUDPPacket(clientPacketData: ByteBuffer, ipHeader: IP4Header) {
val udpHeader = UDPPacketFactory.createUDPHeader(clientPacketData)
var connection = manager.getConnection(
Protocol.UDP,
ipHeader.destinationIP, udpHeader.destinationPort,
ipHeader.sourceIP, udpHeader.sourcePort
)
val newSession = connection == null
if (connection == null) {
connection = manager.createUDPConnection(
ipHeader.destinationIP, udpHeader.destinationPort,
ipHeader.sourceIP, udpHeader.sourcePort
)
}
synchronized(connection) {
connection.lastIpHeader = ipHeader
connection.lastUdpHeader = udpHeader
manager.addClientData(clientPacketData, connection)
connection.isDataForSendingReady = true
// We don't register the session until it's fully populated (as above)
if (newSession) nioService.registerSession(connection)
// Ping the NIO thread to write this, when the session is next writable
connection.subscribeKey(SelectionKey.OP_WRITE)
nioService.refreshSelect(connection)
}
manager.keepSessionAlive(connection)
}
/**
* 是否支持协议
*/
private val methods: List<String> =
mutableListOf("GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "TRACE", "CONNECT")
private fun supperProtocol(packetData: ByteBuffer): Boolean {
val position = packetData.position()
//判断是否是ssl握手
if (isTLSClientHello(packetData) && getDomain(packetData) != null) {
packetData.position(position)
return true
}
packetData.position(position)
for (method in methods) {
if (packetData.remaining() < method.length) {
continue
}
val bytes = ByteArray(method.length)
for (i in bytes.indices) {
bytes[i] = packetData[position + i]
}
if (method.equals(String(bytes), ignoreCase = true)) {
return true
}
}
return false
}
/**
* 获取代理地址
*/
private fun getProxyAddress(
packetData: ByteBuffer,
destinationIP: Int,
destinationPort: Int
): SocketAddress {
val supperProtocol = supperProtocol(packetData)
var socketAddress: SocketAddress? = null
if (supperProtocol) {
socketAddress = manager.proxyAddress
}
if (socketAddress == null) {
val ips = intToIPAddress(destinationIP)
socketAddress = InetSocketAddress(ips, destinationPort)
}
return socketAddress
}
@Throws(IOException::class)
private fun handleTCPPacket(clientPacketData: ByteBuffer, ip4Header: IP4Header) {
val tcpHeader = TCPPacketFactory.createTCPHeader(clientPacketData)
val dataLength = clientPacketData.limit() - clientPacketData.position()
val sourceIP = ip4Header.sourceIP
val destinationIP = ip4Header.destinationIP
val sourcePort = tcpHeader.getSourcePort()
val destinationPort = tcpHeader.getDestinationPort()
if (tcpHeader.isSYN()) {
// 3-way handshake + create new session
replySynAck(ip4Header, tcpHeader)
} else if (tcpHeader.isACK()) {
val key =
getConnectionKey(Protocol.TCP, destinationIP, destinationPort, sourceIP, sourcePort)
val connection = manager.getConnectionByKey(key)
if (connection == null) {
Log.w(TAG, "Ack for unknown session: $key")
if (tcpHeader.isFIN()) {
sendLastAck(ip4Header, tcpHeader)
} else if (!tcpHeader.isRST()) {
sendRstPacket(ip4Header, tcpHeader, dataLength)
}
return
}
synchronized(connection) {
connection.lastIpHeader = ip4Header
connection.lastTcpHeader = tcpHeader
//any data from client?
if (dataLength > 0) {
if (!connection.isInitConnect) {
connection.isInitConnect = true
val proxyAddress =
getProxyAddress(clientPacketData, destinationIP, destinationPort)
try {
val channel =
connection.channel as SocketChannel?
val connected = channel!!.connect(proxyAddress)
connection.isConnected = connected
nioService.registerSession(connection)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
Log.d(
TAG,
"Proxy Initiate connecting key:" + key + " " + channel.localAddress + " to remote tcp server: " + channel.remoteAddress
)
}
} catch (e: Exception) {
val ips = intToIPAddress(destinationIP)
Log.w(
TAG,
"Failed to reconnect to $ips:$destinationPort",
e
)
}
}
//accumulate data from client
if (connection.recSequence == 0L || tcpHeader.sequenceNumber >= connection.recSequence) {
val addedLength = manager.addClientData(clientPacketData, connection)
//send ack to client only if new data was added
sendAck(ip4Header, tcpHeader, addedLength, connection)
} else {
sendAckForDisorder(ip4Header, tcpHeader, dataLength)
}
} else {
//an ack from client for previously sent data
acceptAck(tcpHeader, connection)
if (connection.isClosingConnection) {
sendFinAck(ip4Header, tcpHeader, connection)
} else if (connection.isAckedToFin && !tcpHeader.isFIN()) {
//the last ACK from client after FIN-ACK flag was sent
manager.closeConnection(
Protocol.TCP,
destinationIP,
destinationPort,
sourceIP,
sourcePort
)
// Log.d(TAG, "got last ACK after FIN, session is now closed.");
}
}
//received the last segment of data from vpn client
if (tcpHeader.isPSH()) {
// Tell the NIO thread to immediately send data to the destination
pushDataToDestination(connection, tcpHeader)
} else if (tcpHeader.isFIN()) {
//fin from vpn client is the last packet
//ack it
// Log.d(TAG, "FIN from vpn client, will ack it.");
ackFinAck(ip4Header, tcpHeader, connection)
} else if (tcpHeader.isRST()) {
resetTCPConnection(ip4Header, tcpHeader)
}
if (!connection.isAbortingConnection) {
manager.keepSessionAlive(connection)
}
}
} else if (tcpHeader.isFIN()) {
//case client sent FIN without ACK
val connection = manager.getConnection(
Protocol.TCP,
destinationIP,
destinationPort,
sourceIP,
sourcePort
)
if (connection == null) ackFinAck(
ip4Header,
tcpHeader,
null
) else manager.keepSessionAlive(connection)
} else if (tcpHeader.isRST()) {
resetTCPConnection(ip4Header, tcpHeader)
} else {
Log.d(TAG, "unknown TCP flag")
val str1 = getOutput(ip4Header, tcpHeader, clientPacketData.array())
Log.d(TAG, ">>>>>>>> Received from client <<<<<<<<<<")
Log.d(TAG, str1)
Log.d(TAG, ">>>>>>>>>>>>>>>>>>>end receiving from client>>>>>>>>>>>>>>>>>>>>>")
}
}
private fun sendRstPacket(ip: IP4Header, tcp: TCPHeader, dataLength: Int) {
val data = TCPPacketFactory.createRstData(ip, tcp, dataLength)
writer.write(data)
Log.d(
TAG, "Sent RST Packet to client with dest => " +
intToIPAddress(ip.destinationIP) + ":" +
tcp.getDestinationPort()
)
}
private fun sendLastAck(ip: IP4Header, tcp: TCPHeader) {
val data = TCPPacketFactory.createResponseAckData(ip, tcp, tcp.sequenceNumber + 1)
writer.write(data)
// Log.d(TAG,"Sent last ACK Packet to client with dest => " +
// PacketUtil.intToIPAddress(ip.getDestinationIP()) + ":" +
// tcp.getDestinationPort());
}
private fun ackFinAck(ip: IP4Header, tcp: TCPHeader, connection: Connection?) {
val ack = tcp.sequenceNumber + 1
val seq = tcp.ackNumber
val data = TCPPacketFactory.createFinAckData(ip, tcp, ack, seq, isFin = true, isAck = true)
writer.write(data)
if (connection != null) {
connection.cancelKey()
manager.closeConnection(connection)
// Log.d(TAG,"ACK to client's FIN and close session => "+PacketUtil.intToIPAddress(ip.getDestinationIP())+":"+tcp.getDestinationPort()
// +"-"+PacketUtil.intToIPAddress(ip.getSourceIP())+":"+tcp.getSourcePort());
}
}
private fun sendFinAck(ip: IP4Header, tcp: TCPHeader, connection: Connection) {
val ack = tcp.sequenceNumber
val seq = tcp.ackNumber
val data = TCPPacketFactory.createFinAckData(ip, tcp, ack, seq, isFin = true, isAck = false)
val stream = ByteBuffer.wrap(data)
writer.write(data)
Log.d(TAG, "00000000000 FIN-ACK packet data to vpn client 000000000000")
var vpnIp: IP4Header? = null
try {
vpnIp = IPPacketFactory.createIP4Header(stream)
} catch (e: Exception) {
e.printStackTrace()
}
var vpnTcp: TCPHeader? = null
try {
if (vpnIp != null) vpnTcp = TCPPacketFactory.createTCPHeader(stream)
} catch (e: Exception) {
e.printStackTrace()
}
if (vpnIp != null && vpnTcp != null) {
val logOut = getOutput(vpnIp, vpnTcp, data)
Log.d(TAG, logOut)
}
Log.d(TAG, "0000000000000 finished sending FIN-ACK packet to vpn client 000000000000")
connection.sendNext = seq + 1
//avoid re-sending it, from here client should take care the rest
connection.isClosingConnection = false
}
private fun pushDataToDestination(connection: Connection, tcp: TCPHeader) {
connection.isDataForSendingReady = true
connection.timestampReplyTo = tcp.timeStampSender
connection.timestampSender = System.currentTimeMillis().toInt()
// Ping the NIO thread to write this, when the session is next writable
connection.subscribeKey(SelectionKey.OP_WRITE)
nioService.refreshSelect(connection)
}
/**
* send acknowledgment packet to VPN client
*
* @param acceptedDataLength Data Length
*/
private fun sendAck(
ipHeader: IP4Header, tcpHeader: TCPHeader, acceptedDataLength: Int, connection: Connection
) {
val ackNumber = connection.recSequence + acceptedDataLength
connection.recSequence = ackNumber
val ackData = TCPPacketFactory.createResponseAckData(ipHeader, tcpHeader, ackNumber)
writer.write(ackData)
}
/**
* resend the last acknowledgment packet to VPN client, e.g. when an unexpected out of order
* packet arrives.
*/
private fun resendAck(connection: Connection) {
val data = TCPPacketFactory.createResponseAckData(
connection.lastIpHeader!!,
connection.lastTcpHeader!!,
connection.recSequence
)
writer.write(data)
}
private fun sendAckForDisorder(
ipHeader: IP4Header, tcpHeader: TCPHeader, acceptedDataLength: Int
) {
val ackNumber = tcpHeader.sequenceNumber + acceptedDataLength
Log.e(
TAG, "sent disorder ack, ack# " + tcpHeader.sequenceNumber +
" + " + acceptedDataLength + " = " + ackNumber
)
val data = TCPPacketFactory.createResponseAckData(ipHeader, tcpHeader, ackNumber)
writer.write(data)
}
/**
* acknowledge a packet.
*
* @param tcpHeader TCP Header
*/
private fun acceptAck(tcpHeader: TCPHeader, connection: Connection) {
val isCorrupted = isPacketCorrupted(tcpHeader)
// connection.setPacketCorrupted(isCorrupted);
if (isCorrupted) {
Log.e(TAG, "prev packet was corrupted, last ack# " + tcpHeader.ackNumber)
}
if (tcpHeader.ackNumber > connection.sendUnAck ||
tcpHeader.ackNumber == connection.sendNext
) {
// connection.setAcked(true);
connection.sendUnAck = tcpHeader.ackNumber
connection.recSequence = tcpHeader.sequenceNumber
connection.timestampReplyTo = tcpHeader.timeStampSender
connection.timestampSender = System.currentTimeMillis().toInt()
} else {
Log.d(
TAG,
"Not Accepting ack# " + tcpHeader.ackNumber + " , it should be: " + connection.sendNext
)
Log.d(TAG, "Prev sendUnAck: " + connection.sendUnAck)
// connection.setAcked(false);
}
}
/**
* set connection as aborting so that background worker will close it.
*
* @param ip IP
* @param tcp TCP
*/
private fun resetTCPConnection(ip: IP4Header, tcp: TCPHeader) {
val session = manager.getConnection(
Protocol.TCP,
ip.destinationIP, tcp.getDestinationPort(),
ip.sourceIP, tcp.getSourcePort()
)
if (session != null) {
synchronized(session) { session.isAbortingConnection = true }
}
}
/**
* create a new client's session and SYN-ACK packet data to respond to client
*/
@Throws(IOException::class)
private fun replySynAck(ipHeader: IP4Header, tcpHeader: TCPHeader) {
ipHeader.identification = 0
val packet = TCPPacketFactory.createSynAckPacketData(ipHeader, tcpHeader)
val tcpTransport = packet.transportHeader as TCPHeader
val connection = manager.createTCPConnection(
ipHeader.destinationIP, tcpHeader.getDestinationPort(),
ipHeader.sourceIP, tcpHeader.getSourcePort()
)
if (connection.lastIpHeader != null) {
// We have an existing session for this connection! We've somehow received a SYN
// for an existing socket (or some kind of other race). We resend the last ACK
// for this session, rejecting this SYN. Not clear why this happens, but it can.
resendAck(connection)
return
}
synchronized(connection) {
connection.maxSegmentSize = tcpTransport.maxSegmentSize.toInt()
connection.sendUnAck = tcpTransport.sequenceNumber
connection.sendNext = tcpTransport.sequenceNumber + 1
//client initial sequence has been incremented by 1 and set to ack
connection.recSequence = tcpTransport.ackNumber
connection.lastIpHeader = ipHeader
connection.lastTcpHeader = tcpHeader
if (connection.isInitConnect) {
nioService.registerSession(connection)
}
writer.write(packet.buffer)
}
}
private fun handleICMPPacket(clientPacketData: ByteBuffer, ipHeader: IP4Header) {
val requestPacket = ICMPPacketFactory.parseICMPPacket(clientPacketData)
Log.d(TAG, "Got an ICMP ping packet, type $requestPacket")
if (requestPacket.type == ICMPPacket.DESTINATION_UNREACHABLE_TYPE) {
// This is a packet from the phone, telling somebody that a destination is unreachable.
// Might be caused by issues on our end, but it's unclear what kind of issues. Regardless,
// we can't send ICMP messages ourselves or react usefully, so we drop these silently.
return
} else require(requestPacket.type == ICMPPacket.ECHO_REQUEST_TYPE) {
// We only actually support outgoing ping packets. Loudly drop anything else:
"Unknown ICMP type (" + requestPacket.type + "). Only echo requests are supported"
}
pingThreadPool.execute(object : Runnable {
override fun run() {
try {
if (!isReachable(intToIPAddress(ipHeader.destinationIP))) {
Log.d(TAG, "Failed ping, ignoring")
return
}
val response = ICMPPacketFactory.buildSuccessPacket(requestPacket)
// Flip the address
val destination = ipHeader.destinationIP
val source = ipHeader.sourceIP
ipHeader.sourceIP = destination
ipHeader.destinationIP = source
val responseData = ICMPPacketFactory.packetToBuffer(ipHeader, response)
Log.d(TAG, "Successful ping response")
writer.write(responseData)
} catch (e: Exception) {
Log.w(TAG, "Handling ICMP failed with " + e.message)
return
}
}
private fun isReachable(ipAddress: String): Boolean {
return try {
InetAddress.getByName(ipAddress).isReachable(10000)
} catch (e: IOException) {
false
}
}
})
}
}

View File

@@ -0,0 +1,159 @@
package com.network.proxy.vpn
import android.os.Build
import android.util.Log
import com.network.proxy.vpn.socket.CloseableConnection
import com.network.proxy.vpn.socket.Constant
import com.network.proxy.vpn.socket.ProtectSocketHolder.Companion.protect
import com.network.proxy.vpn.util.PacketUtil
import java.io.IOException
import java.net.InetSocketAddress
import java.net.SocketAddress
import java.nio.ByteBuffer
import java.nio.channels.DatagramChannel
import java.nio.channels.SocketChannel
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentMap
/**
* 管理VPN客户端的连接
*/
class ConnectionManager : CloseableConnection {
private val table: ConcurrentMap<String, Connection> = ConcurrentHashMap()
var proxyAddress: InetSocketAddress? = null
var DEFAULT_PORTS: List<Int> = listOf(
80, // HTTP
443, // HTTPS
8080, // Common local dev ports
8000, 8080, 8888, 9000 // Common local dev ports
)
override fun closeConnection(connection: Connection) {
closeConnection(
connection.protocol, connection.destinationIp, connection.destinationPort,
connection.sourceIp, connection.sourcePort
)
}
/**
* 从内存中删除连接,然后关闭套接字。
*
*/
fun closeConnection(protocol: Protocol, ip: Int, port: Int, srcIp: Int, srcPort: Int) {
val key = Connection.getConnectionKey(protocol, ip, port, srcIp, srcPort)
val session: Connection? = table.remove(key)
session?.let {
val channel = session.channel
try {
channel?.close()
} catch (e: IOException) {
e.printStackTrace()
}
}
}
fun getConnection(
protocol: Protocol, ip: Int, port: Int, srcIp: Int, srcPort: Int
): Connection? {
val key = Connection.getConnectionKey(protocol, ip, port, srcIp, srcPort)
return getConnectionByKey(key)
}
fun getConnectionByKey(key: String?): Connection? {
return table[key]
}
/**
* 创建tcp连接
*/
fun createTCPConnection(ip: Int, port: Int, srcIp: Int, srcPort: Int): Connection {
val key = Connection.getConnectionKey(Protocol.TCP, ip, port, srcIp, srcPort)
val existingConnection: Connection? = table[key]
if (existingConnection != null) {
return existingConnection
}
val connection = Connection(Protocol.TCP, srcIp, srcPort, ip, port, this)
val channel: SocketChannel = SocketChannel.open()
channel.socket().keepAlive = true
channel.socket().tcpNoDelay = true
channel.socket().soTimeout = 0
channel.socket().receiveBufferSize = Constant.MAX_RECEIVE_BUFFER_SIZE
channel.configureBlocking(false)
Log.d(TAG, "created new SocketChannel for $key")
protect(channel.socket())
connection.channel = channel
val socketAddress: SocketAddress? = null
// if (!DEFAULT_PORTS.contains(port)) {
// socketAddress = new InetSocketAddress(ips, port);
// }
// if (!DEFAULT_PORTS.contains(port)) {
// socketAddress = new InetSocketAddress(ips, port);
// }
connection.isInitConnect = socketAddress != null
if (socketAddress != null) {
val connected = channel.connect(socketAddress)
connection.isConnected = connected
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
Log.d(
TAG,
"Initiate connecting " + channel.localAddress + " to remote tcp server: " + channel.remoteAddress
)
}
}
table[key] = connection
return connection
}
@Throws(IOException::class)
fun createUDPConnection(ip: Int, port: Int, srcIp: Int, srcPort: Int): Connection {
val keys = Connection.getConnectionKey(Protocol.UDP, ip, port, srcIp, srcPort)
val existingConnection: Connection? = table[keys]
if (existingConnection != null) return existingConnection
val connection = Connection(Protocol.UDP, srcIp, srcPort, ip, port, this)
val channel: DatagramChannel = DatagramChannel.open()
channel.socket().soTimeout = 0
channel.configureBlocking(false)
protect(channel.socket())
connection.channel = channel
// Initiate connection early to reduce latency
val ips = PacketUtil.intToIPAddress(ip)
val socketAddress: SocketAddress = InetSocketAddress(ips, port)
channel.connect(socketAddress)
connection.isConnected = channel.isConnected
table[keys] = connection
return connection
}
/**
* 添加来自客户端的数据该数据稍后将在接收到PSH标志时发送到目的服务器。
*/
fun addClientData(buffer: ByteBuffer, session: Connection): Int {
return if (buffer.limit() <= buffer.position()) 0 else session.setSendingData(buffer)
}
/**
* 阻止java垃圾收集器收集会话
*/
fun keepSessionAlive(connection: Connection) {
val key = Connection.getConnectionKey(
connection.protocol, connection.destinationIp, connection.destinationPort,
connection.sourceIp, connection.sourcePort
)
table[key] = connection
}
}

View File

@@ -0,0 +1,6 @@
package com.network.proxy.vpn;
public enum Protocol {
TCP,
UDP
}

View File

@@ -0,0 +1,105 @@
package com.network.proxy.vpn
import android.os.ParcelFileDescriptor
import android.util.Log
import com.network.proxy.ProxyVpnService.Companion.MAX_PACKET_LEN
import com.network.proxy.vpn.socket.ClientPacketWriter
import com.network.proxy.vpn.socket.SocketNIODataService
import java.io.FileInputStream
import java.io.FileOutputStream
import java.io.InterruptedIOException
import java.net.InetSocketAddress
import java.nio.ByteBuffer
/**
* VPN线程负责处理VPN接收到的数据包
*/
class ProxyVpnThread(
vpnInterface: ParcelFileDescriptor,
proxyHost: String,
proxyPort: Int,
) : Thread("Vpn thread") {
companion object {
const val TAG = "ProxyVpnThread"
}
@Volatile
private var running = false
private val vpnReadChannel = FileInputStream(vpnInterface.fileDescriptor).channel
// 此VPN接收的来自上游服务器的数据包
private val vpnWriteStream = FileOutputStream(vpnInterface.fileDescriptor)
private val vpnPacketWriter = ClientPacketWriter(vpnWriteStream)
private val vpnPacketWriterThread = Thread(vpnPacketWriter)
// Background service & task for non-blocking socket
private val nioService = SocketNIODataService(vpnPacketWriter)
private val dataServiceThread = Thread(nioService, "Socket NIO thread")
private val manager = ConnectionManager().apply {
//流量转发到代理地址
this.proxyAddress = InetSocketAddress(proxyHost, proxyPort)
}
private val handler = ConnectionHandler(manager, nioService, vpnPacketWriter)
private var currentThread: Thread? = null
override fun run() {
Log.i(TAG, "Vpn thread starting")
currentThread = currentThread()
dataServiceThread.start()
vpnPacketWriterThread.start()
val readBuffer = ByteBuffer.allocate(MAX_PACKET_LEN)
running = true
while (running) {
try {
val length = vpnReadChannel.read(readBuffer)
if (length > 0) {
try {
readBuffer.flip()
val byteArray = ByteArray(length)
readBuffer.get(byteArray)
val packet = ByteBuffer.wrap(byteArray)
handler.handlePacket(packet)
} catch (e: Exception) {
val errorMessage = (e.message ?: e.toString())
Log.e(TAG, errorMessage, e)
}
readBuffer.clear()
} else {
sleep(50)
}
} catch (e: InterruptedException) {
Log.i(TAG, "Sleep interrupted: " + e.message)
} catch (e: InterruptedIOException) {
Log.i(TAG, "Read interrupted: " + e.message)
} catch (e: Exception) {
val errorMessage = (e.message ?: e.toString())
Log.e(TAG, errorMessage, e)
}
}
Log.i(TAG, "Vpn thread stop")
}
@Synchronized
fun stopThread() {
if (running) {
running = false
nioService.shutdown()
dataServiceThread.interrupt()
vpnPacketWriter.shutdown()
vpnPacketWriterThread.interrupt()
currentThread?.interrupt()
}
}
}

View File

@@ -0,0 +1,46 @@
package com.network.proxy.vpn.socket
import android.util.Log
import java.io.FileOutputStream
import java.io.IOException
import java.util.concurrent.BlockingDeque
import java.util.concurrent.LinkedBlockingDeque
import kotlin.concurrent.Volatile
class ClientPacketWriter(private val clientWriter: FileOutputStream) : Runnable {
companion object {
private const val TAG: String = "ClientPacketWriter"
private const val MAX_PACKET_LEN = 32767
}
@Volatile
private var shutdown = false
private val packetQueue: BlockingDeque<ByteArray> = LinkedBlockingDeque()
fun write(data: ByteArray) {
if (data.size > MAX_PACKET_LEN) throw Error("Packet too large")
packetQueue.addLast(data)
}
fun shutdown() {
this.shutdown = true
}
override fun run() {
while (!this.shutdown) {
try {
val data: ByteArray = this.packetQueue.take()
try {
this.clientWriter.write(data)
} catch (e: IOException) {
Log.e(TAG, "Error writing $shutdown data.length bytes to the VPN")
e.printStackTrace()
this.packetQueue.addFirst(data) // Put the data back, so it's resent
Thread.sleep(10) // Add an arbitrary tiny pause, in case that helps
}
} catch (ignored: InterruptedException) {
}
}
}
}

View File

@@ -0,0 +1,10 @@
package com.network.proxy.vpn.socket
import com.network.proxy.vpn.Connection
interface CloseableConnection {
/**
* 关闭连接
*/
fun closeConnection(session: Connection)
}

View File

@@ -0,0 +1,5 @@
package com.network.proxy.vpn.socket
object Constant {
const val MAX_RECEIVE_BUFFER_SIZE = 65535
}

View File

@@ -0,0 +1,211 @@
package com.network.proxy.vpn.socket;
import androidx.annotation.NonNull;
import android.util.Log;
import com.network.proxy.vpn.Connection;
import com.network.proxy.vpn.TagKt;
import com.network.proxy.vpn.transport.protocol.IP4Header;
import com.network.proxy.vpn.transport.protocol.TCPHeader;
import com.network.proxy.vpn.transport.protocol.TCPPacketFactory;
import com.network.proxy.vpn.transport.protocol.UDPPacketFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.DatagramChannel;
import java.nio.channels.NotYetConnectedException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.AbstractSelectableChannel;
/**
* Takes a session, and reads all available upstream data back into it.
* Used by the NIO thread, and run synchronously as part of that non-blocking loop.
*/
class SocketChannelReader {
private final String TAG = TagKt.getTAG(this);
private final ClientPacketWriter writer;
public SocketChannelReader(ClientPacketWriter writer) {
this.writer = writer;
}
public void read(Connection connection) {
AbstractSelectableChannel channel = connection.getChannel();
if (channel instanceof SocketChannel) {
readTCP(connection);
} else if (channel instanceof DatagramChannel) {
readUDP(connection);
} else {
return;
}
// Resubscribe to reads, so that we're triggered again if more data arrives later.
connection.subscribeKey(SelectionKey.OP_READ);
if (connection.isAbortingConnection()) {
Log.d(TAG, "removing aborted connection -> " + connection);
connection.cancelKey();
if (channel instanceof SocketChannel) {
try {
SocketChannel socketChannel = (SocketChannel) channel;
if (socketChannel.isConnected()) {
socketChannel.close();
}
} catch (IOException e) {
Log.e(TAG, e.toString());
}
} else {
try {
DatagramChannel datagramChannel = (DatagramChannel) channel;
if (datagramChannel.isConnected()) {
datagramChannel.close();
}
} catch (IOException e) {
e.printStackTrace();
}
}
connection.closeConnection();
}
}
private void readTCP(@NonNull Connection connection) {
if (connection.isAbortingConnection()) {
return;
}
SocketChannel channel = (SocketChannel) connection.getChannel();
ByteBuffer buffer = ByteBuffer.allocate(Constant.MAX_RECEIVE_BUFFER_SIZE);
int len;
try {
do {
len = channel.read(buffer);
if (len > 0) { //-1 mean it reach the end of stream
sendToRequester(buffer, len, connection);
buffer.clear();
} else if (len == -1) {
// Log.d(TAG,"End of data from remote server, will send FIN to client");
Log.d(TAG, "send FIN to: " + connection);
sendFin(connection);
connection.setAbortingConnection(true);
}
} while (len > 0);
} catch (NotYetConnectedException e) {
Log.e(TAG, "socket not connected");
} catch (ClosedByInterruptException e) {
Log.e(TAG, "ClosedByInterruptException reading SocketChannel: " + e.getMessage());
} catch (ClosedChannelException e) {
Log.e(TAG, "ClosedChannelException reading SocketChannel: " + e.getMessage());
} catch (IOException e) {
Log.e(TAG, "Error reading data from SocketChannel: " + e.getMessage());
connection.setAbortingConnection(true);
}
}
private void sendToRequester(ByteBuffer buffer, int dataSize, @NonNull Connection connection) {
// Last piece of data is usually smaller than MAX_RECEIVE_BUFFER_SIZE. We use this as a
// trigger to set PSH on the resulting TCP packet that goes to the VPN.
connection.setHasReceivedLastSegment(dataSize < Constant.MAX_RECEIVE_BUFFER_SIZE);
buffer.limit(dataSize);
buffer.flip();
// TODO should allocate new byte array?
byte[] data = new byte[dataSize];
System.arraycopy(buffer.array(), 0, data, 0, dataSize);
connection.addReceivedData(data);
//pushing all data to vpn client
while (connection.hasReceivedData()) {
pushDataToClient(connection);
}
}
/**
* create packet data and send it to VPN client
*/
private void pushDataToClient(@NonNull Connection connection) {
if (!connection.hasReceivedData()) {
//no data to send
Log.d(TAG, "no data for vpn client");
}
IP4Header ipHeader = connection.getLastIpHeader();
TCPHeader tcpheader = connection.getLastTcpHeader();
// TODO What does 60 mean?
int max = connection.getMaxSegmentSize() - 60;
if (max < 1) {
max = 1024;
}
byte[] packetBody = connection.getReceivedData(max);
if (packetBody != null && packetBody.length > 0) {
long unAck = connection.getSendNext();
long nextUnAck = connection.getSendNext() + packetBody.length;
connection.setSendNext((int) nextUnAck);
//we need this data later on for retransmission
// connection.setUnackData(packetBody);
// connection.setResendPacketCounter(0);
byte[] data = TCPPacketFactory.createResponsePacketData(ipHeader,
tcpheader, packetBody, connection.getHasReceivedLastSegment(),
connection.getRecSequence(), (int) unAck,
connection.getTimestampSender(), connection.getTimestampReplyTo());
writer.write(data);
}
}
private void sendFin(Connection connection) {
final IP4Header ipHeader = connection.getLastIpHeader();
final TCPHeader tcpheader = connection.getLastTcpHeader();
final byte[] data = TCPPacketFactory.INSTANCE.createFinData(ipHeader, tcpheader,
connection.getRecSequence(), connection.getSendNext(),
connection.getTimestampSender(), connection.getTimestampReplyTo());
writer.write(data);
}
private void readUDP(Connection connection) {
DatagramChannel channel = (DatagramChannel) connection.getChannel();
ByteBuffer buffer = ByteBuffer.allocate(Constant.MAX_RECEIVE_BUFFER_SIZE);
int len;
try {
do {
if (connection.isAbortingConnection()) {
break;
}
len = channel.read(buffer);
if (len > 0) {
buffer.limit(len);
buffer.flip();
//create UDP packet
byte[] data = new byte[len];
System.arraycopy(buffer.array(), 0, data, 0, len);
byte[] packetData = UDPPacketFactory.createResponsePacket(
connection.getLastIpHeader(), connection.getLastUdpHeader(), data);
//write to client
writer.write(packetData);
buffer.clear();
}
} while (len > 0);
} catch (NotYetConnectedException ex) {
Log.e(TAG, "failed to read from unconnected UDP socket");
} catch (IOException e) {
Log.e(TAG, "Failed to read from UDP socket, aborting connection");
connection.setAbortingConnection(true);
}
}
}

View File

@@ -0,0 +1,148 @@
package com.network.proxy.vpn.socket;
import androidx.annotation.NonNull;
import android.util.Log;
import com.network.proxy.vpn.Connection;
import com.network.proxy.vpn.TagKt;
import com.network.proxy.vpn.transport.protocol.TCPPacketFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.DatagramChannel;
import java.nio.channels.NotYetConnectedException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.AbstractSelectableChannel;
/**
* Takes a VPN session, and writes all received data from it to the upstream channel.
* <p>
* If any writes fail, it resubscribes to OP_WRITE, and tries again next time
* that fires (as soon as the channel is ready for more data).
* <p>
* Used by the NIO thread, and run synchronously as part of that non-blocking loop.
*/
public class SocketChannelWriter {
private final String TAG = TagKt.getTAG(this);
private final ClientPacketWriter writer;
SocketChannelWriter(ClientPacketWriter writer) {
this.writer = writer;
}
public void write(@NonNull Connection connection) {
AbstractSelectableChannel channel = connection.getChannel();
if (channel instanceof SocketChannel) {
writeTCP(connection);
} else if(channel instanceof DatagramChannel) {
writeUDP(connection);
} else {
// We only ever create TCP & UDP channels, so this should never happen
throw new IllegalArgumentException("Unexpected channel type: " + channel);
}
if (connection.isAbortingConnection()) {
Log.d(TAG,"removing aborted connection -> " + connection);
connection.cancelKey();
if (channel instanceof SocketChannel) {
try {
SocketChannel socketChannel = (SocketChannel) channel;
if (socketChannel.isConnected()) {
socketChannel.close();
}
} catch (IOException e) {
e.printStackTrace();
}
} else {
try {
DatagramChannel datagramChannel = (DatagramChannel) channel;
if (datagramChannel.isConnected()) {
datagramChannel.close();
}
} catch (IOException e) {
e.printStackTrace();
}
}
connection.closeConnection();
}
}
private void writeUDP(Connection connection) {
try {
writePendingData(connection);
// Date dt = new Date();
// connection.connectionStartTime = dt.getTime();
}catch(NotYetConnectedException ex2){
connection.setAbortingConnection(true);
Log.e(TAG,"Error writing to unconnected-UDP server, will abort current connection: "+ex2.getMessage());
} catch (IOException e) {
connection.setAbortingConnection(true);
e.printStackTrace();
Log.e(TAG,"Error writing to UDP server, will abort connection: "+e.getMessage());
}
}
private void writeTCP(Connection connection) {
try {
writePendingData(connection);
} catch (NotYetConnectedException ex) {
Log.e(TAG,"failed to write to unconnected socket: " + ex.getMessage());
} catch (IOException e) {
Log.e(TAG,"Error writing to server: " + e);
//close connection with vpn client
byte[] rstData = TCPPacketFactory.INSTANCE.createRstData(
connection.getLastIpHeader(), connection.getLastTcpHeader(), 0);
writer.write(rstData);
//remove session
Log.e(TAG,"failed to write to remote socket, aborting connection");
connection.setAbortingConnection(true);
}
}
private void writePendingData(Connection connection) throws IOException {
if (!connection.hasDataToSend()) return;
AbstractSelectableChannel channel = connection.getChannel();
byte[] data = connection.getSendingData();
ByteBuffer buffer = ByteBuffer.allocate(data.length);
buffer.put(data);
buffer.flip();
while (buffer.hasRemaining()) {
int bytesWritten = channel instanceof SocketChannel
? ((SocketChannel) channel).write(buffer)
: ((DatagramChannel) channel).write(buffer);
if (bytesWritten == 0) {
break;
}
}
if (buffer.hasRemaining()) {
// The channel's own buffer is full, so we have to save this for later.
Log.i(TAG, buffer.remaining() + " bytes unwritten for " + channel);
// Put the remaining data from the buffer back into the session
connection.setSendingData(buffer.compact());
// Subscribe to WRITE events, so we know when this is ready to resume.
connection.subscribeKey(SelectionKey.OP_WRITE);
} else {
// All done, all good -> wait until the next TCP PSH / UDP packet
connection.setDataForSendingReady(false);
// We don't need to know about WRITE events any more, we've written all our data.
// This is safe from races with new data, due to the session lock in NIO.
connection.unsubscribeKey(SelectionKey.OP_WRITE);
}
}
}

View File

@@ -0,0 +1,247 @@
package com.network.proxy.vpn.socket;
import android.util.Log;
import com.network.proxy.vpn.Connection;
import com.network.proxy.vpn.TagKt;
import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.DatagramChannel;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.AbstractSelectableChannel;
import java.util.Iterator;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
* A service that single-threadedly processes the events around our session connections,
* entirely via non-blocking NIO.
* <p>
* It uses a Selector that fires on outgoing socket events (connected, readable, writable),
* handles the resulting operations, and keeps those subscriptions up to date.
*/
public class SocketNIODataService implements Runnable {
private final String TAG = TagKt.getTAG(this);
private final ReentrantLock nioSelectionLock = new ReentrantLock();
private final ReentrantLock nioHandlingLock = new ReentrantLock();
private final Selector selector = Selector.open();
private final SocketChannelReader reader;
private final SocketChannelWriter writer;
private volatile boolean shutdown = false;
public SocketNIODataService(ClientPacketWriter clientPacketWriter) throws IOException {
reader = new SocketChannelReader(clientPacketWriter);
writer = new SocketChannelWriter(clientPacketWriter);
}
@Override
public void run() {
Log.d(TAG,"SocketNIODataService starting in background...");
runTask();
}
public void registerSession(Connection connection) throws ClosedChannelException {
AbstractSelectableChannel channel = connection.getChannel();
boolean isConnected = channel instanceof DatagramChannel
? ((DatagramChannel) channel).isConnected()
: ((SocketChannel) channel).isConnected();
// Log.i(TAG, "Registering new session: " + session);
Lock selectorLock = lockSelector(selector);
try {
SelectionKey selectionKey = channel.register(selector,
isConnected
? SelectionKey.OP_READ
: SelectionKey.OP_CONNECT
);
connection.setSelectionKey(selectionKey);
selectionKey.attach(connection);
// Log.d(TAG, "Registered selector successfully");
} finally {
selectorLock.unlock();
}
}
private Lock lockSelector(Selector selector) {
boolean gotSelectionLock = nioSelectionLock.tryLock();
if (gotSelectionLock) return nioSelectionLock;
nioHandlingLock.lock(); // Ensure the NIO thread can't do anything on wakeup
selector.wakeup();
nioSelectionLock.lock(); // Actually get the lock we want
nioHandlingLock.unlock(); // Release the handling lock, which we no longer care about
return nioSelectionLock;
}
/**
* If the selector is currently select()ing, wake it up (e.g. to register changes to
* interestOps). If it's not (and so it probably will select() very soon anyway) do nothing.
* This is designed to be run after changing readyOps, to ensure the new ops get monitored
* immediately (and fire immediately, if already ready). Without this, that blocks.
*/
public void refreshSelect(Connection connection) {
boolean gotLock = nioSelectionLock.tryLock();
if (!gotLock) {
connection.getSelectionKey().selector().wakeup();
} else {
nioSelectionLock.unlock();
}
}
/**
* Shut down the NIO thread
*/
public void shutdown(){
this.shutdown = true;
selector.wakeup();
}
private void runTask(){
Log.i(TAG, "NIO selector is running...");
while(!shutdown){
try {
nioSelectionLock.lockInterruptibly();
selector.select();
} catch (IOException e) {
Log.e(TAG,"Error in Selector.select(): " + e.getMessage());
try {
Thread.sleep(100);
} catch (InterruptedException ex) {
Log.e(TAG, e.toString());
}
continue;
} catch (InterruptedException ex) {
Log.i(TAG, "Select() interrupted");
} finally {
if (nioSelectionLock.isHeldByCurrentThread()) {
nioSelectionLock.unlock();
}
}
if (shutdown) {
break;
}
// A lock here makes it possible to reliably grab the selection lock above
nioHandlingLock.lock();
try {
Iterator<SelectionKey> iterator = selector.selectedKeys().iterator();
while (iterator.hasNext()) {
SelectionKey key = iterator.next();
Connection connection = ((Connection) key.attachment());
synchronized (connection) { // Sessions are locked during processing (no VPN data races)
try {
processSelectionKey(key);
} catch (IOException e) {
synchronized (key) {
key.cancel();
}
}
}
iterator.remove();
if (shutdown) {
break;
}
}
} finally {
nioHandlingLock.unlock();
}
}
Log.i(TAG, "NIO selector shutdown");
}
private void processSelectionKey(SelectionKey key) throws IOException {
if (!key.isValid()) {
Log.d(TAG,"Invalid SelectionKey");
return;
}
SelectableChannel channel = key.channel();
Connection connection = ((Connection) key.attachment());
if (connection == null) {
Log.w(TAG, "Key fired with no session attached");
return;
}
if (channel instanceof SocketChannel && !connection.isConnected() && key.isConnectable()) {
SocketChannel socketChannel = (SocketChannel) channel;
if (socketChannel.isConnectionPending()) {
boolean connected = socketChannel.finishConnect();
connection.setConnected(connected);
} else {
throw new IllegalStateException("TCP channels must either be connected or pending connection");
}
}
if (isConnected(channel)) {
processConnectedSelection(key, connection);
}
}
private boolean isConnected(SelectableChannel channel) {
if (channel instanceof DatagramChannel) {
return ((DatagramChannel) channel).isConnected();
} else if (channel instanceof SocketChannel) {
return ((SocketChannel) channel).isConnected();
} else {
throw new IllegalArgumentException("isConnected on unexpected channel type: " + channel);
}
}
private void processConnectedSelection(SelectionKey key, Connection connection) {
// Whilst connected, we always want READ and not CONNECT events
connection.unsubscribeKey(SelectionKey.OP_CONNECT);
connection.subscribeKey(SelectionKey.OP_READ);
processSelectorRead(key, connection);
processPendingWrite(key, connection);
}
private void processSelectorRead(SelectionKey selectionKey, Connection connection) {
boolean canRead;
synchronized (selectionKey) {
// There's a race here that requires a lock, as isReadable requires isValid
canRead = selectionKey.isValid() && selectionKey.isReadable();
}
if (canRead) reader.read(connection);
}
private void processPendingWrite(SelectionKey selectionKey, Connection connection) {
// Nothing to write? Skip this entirely, and make sure we're not subscribed
if (!connection.hasDataToSend() || !connection.isDataForSendingReady()) {
connection.unsubscribeKey(SelectionKey.OP_WRITE);
return;
}
boolean canWrite;
synchronized (selectionKey) {
// There's a race here that requires a lock, as isReadable requires isValid
canWrite = selectionKey.isValid() && selectionKey.isWritable();
}
if (canWrite) {
connection.unsubscribeKey(SelectionKey.OP_WRITE);
writer.write(connection); // This will resubscribe to OP_WRITE if it can't complete
}
}
}

View File

@@ -0,0 +1,7 @@
package com.network.proxy.vpn.transport
import com.network.proxy.vpn.transport.protocol.IP4Header
import com.network.proxy.vpn.transport.protocol.TransportHeader
class Packet(var ipHeader: IP4Header, var transportHeader: TransportHeader, var buffer: ByteArray) {
}

View File

@@ -0,0 +1,46 @@
package com.network.proxy.vpn.transport.icmp;
import androidx.annotation.NonNull;
public class ICMPPacket {
// Two ICMP packets we can handle: simple ping & pong
public static final byte ECHO_REQUEST_TYPE = 8;
public static final byte ECHO_SUCCESS_TYPE = 0;
// One very common packet we ignore: connection rejection. Unclear why this happens,
// random incoming connections that the phone tries to reply to? Nothing we can do though,
// as we can't forward ICMP onwards, and we can't usefully respond or react.
public static final byte DESTINATION_UNREACHABLE_TYPE = 3;
public final byte type;
final byte code; // 0 for request, 0 for success, 0 - 15 for error subtypes
final int checksum;
final int identifier;
final int sequenceNumber;
final byte[] data;
ICMPPacket(
int type,
int code,
int checksum,
int identifier,
int sequenceNumber,
byte[] data
) {
this.type = (byte) type;
this.code = (byte) code;
this.checksum = checksum;
this.identifier = identifier;
this.sequenceNumber = sequenceNumber;
this.data = data;
}
@NonNull
public String toString() {
return "ICMP packet type " + type + "/" + code + " id:" + identifier +
" seq:" + sequenceNumber + " and " + data.length + " bytes of data";
}
}

View File

@@ -0,0 +1,80 @@
package com.network.proxy.vpn.transport.icmp;
import androidx.annotation.NonNull;
import com.network.proxy.vpn.transport.protocol.IP4Header;
import com.network.proxy.vpn.util.PacketUtil;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
public class ICMPPacketFactory {
public static ICMPPacket parseICMPPacket(@NonNull ByteBuffer stream) {
final byte type = stream.get();
final byte code = stream.get();
final int checksum = stream.getShort();
final int identifier = stream.getShort();
final int sequenceNumber = stream.getShort();
final byte[] data = new byte[stream.remaining()];
stream.get(data);
return new ICMPPacket(type, code, checksum, identifier, sequenceNumber, data);
}
public static ICMPPacket buildSuccessPacket(ICMPPacket requestPacket) {
return new ICMPPacket(
0,
0,
0,
requestPacket.identifier,
requestPacket.sequenceNumber,
requestPacket.data
);
}
public static byte[] packetToBuffer(IP4Header ipHeader, ICMPPacket packet) {
byte[] ipData = ipHeader.toBytes();
ByteArrayOutputStream icmpDataBuffer = new ByteArrayOutputStream();
icmpDataBuffer.write(packet.type);
icmpDataBuffer.write(packet.code);
icmpDataBuffer.write(asShortBytes(0 /* checksum placeholder */), 0, 2);
if (packet.type == ICMPPacket.ECHO_REQUEST_TYPE || packet.type == ICMPPacket.ECHO_SUCCESS_TYPE) {
icmpDataBuffer.write(asShortBytes(packet.identifier), 0, 2);
icmpDataBuffer.write(asShortBytes(packet.sequenceNumber), 0, 2);
byte[] extraData = packet.data;
icmpDataBuffer.write(extraData, 0, extraData.length);
} else {
throw new IllegalArgumentException("Can't serialize unrecognized ICMP packet type");
}
byte[] icmpPacketData = icmpDataBuffer.toByteArray();
byte[] checksum = PacketUtil.INSTANCE.calculateChecksum(icmpPacketData, 0, icmpPacketData.length);
ByteBuffer resultBuffer = ByteBuffer.allocate(ipData.length + icmpPacketData.length);
resultBuffer.put(ipData);
resultBuffer.put(icmpPacketData);
// Replace the checksum placeholder
resultBuffer.position(ipData.length + 2);
resultBuffer.put(checksum);
resultBuffer.position(0);
byte[] result = new byte[resultBuffer.remaining()];
resultBuffer.get(result);
return result;
}
private static byte[] asShortBytes(int value) {
return ByteBuffer.allocate(2).putShort((short) value).array();
}
}

View File

@@ -0,0 +1,142 @@
package com.network.proxy.vpn.transport.protocol
import java.nio.ByteBuffer
import java.nio.ByteOrder
/**
* IPv4报头的数据结构。
*/
data class IP4Header(
var ipVersion: Byte = 0, //对于IPv4其值为4因此命名为IPv4。 4bit
private var internetHeaderLength: Byte = 0, //头部长度 4bit
private var diffTypeOfService: Byte, //差分服务代码点 =>6位
private var ecn: Byte = 0, //显式拥塞通知ECN
var totalLength: Int = 0, //此IP数据包的总长度 16bit
var identification: Int = 0, //主要用于唯一标识单个IP数据报的片段组。 16bit
private var mayFragment: Boolean, // 1bit 用于指示数据报是否可以分段。
private var lastFragment: Boolean, // 1bit 用于指示数据报是否是片段中的最后一个。
var fragmentOffset: Short = 0, //13bit指定特定片段相对于原始未分段的IP数据报的开始的偏移量。
private var timeToLive: Byte = 0, //用于防止数据报持续存在。8bit
var protocol: Byte = 0, //定义IP数据报的数据部分中使用的协议。 8bit
var headerChecksum: Int = 0, //用于对头部进行错误检查的16位字段。 16bit
var sourceIP: Int = 0, //发送者的IPv4地址。 32bit
var destinationIP: Int = 0 //接收者的IPv4地址。 32bit
) {
//用于控制或识别片段的3比特字段。
//bit 0: 保留;必须为零
//bit 1: Don't Fragment (DF)
//bit 2: More Fragments (MF)
private var flag: Byte = initFlag()
private fun initFlag(): Byte {
var initFlag = 0
if (mayFragment) {
initFlag = 0x40
}
if (lastFragment) {
initFlag = (initFlag or 0x20)
}
return initFlag.toByte()
}
fun setMayFragment(mayFragment: Boolean) {
this.mayFragment = mayFragment
flag = if (mayFragment) {
(flag.toInt() or 0x40).toByte()
} else {
(flag.toInt() and 0xBF).toByte()
}
}
fun getIPHeaderLength(): Int {
return internetHeaderLength * 4
}
fun copy(): IP4Header {
return IP4Header(
ipVersion, internetHeaderLength, diffTypeOfService, ecn, totalLength, identification,
mayFragment, lastFragment, fragmentOffset, timeToLive, protocol, headerChecksum,
sourceIP, destinationIP
)
}
fun toBytes(): ByteArray {
val buffer = ByteBuffer.allocate(getIPHeaderLength())
buffer.order(ByteOrder.BIG_ENDIAN)
val versionAndHeaderLength = (ipVersion.toInt() shl 4) + internetHeaderLength
buffer.put(versionAndHeaderLength.toByte())
val typeOfService: Byte = (diffTypeOfService.toInt() shl 2 and (ecn
.toInt() and 0xFF)).toByte()
buffer.put(typeOfService)
buffer.putShort(totalLength.toShort())
buffer.putShort(identification.toShort())
//组合标志和部分片段偏移
buffer.put((fragmentOffset.toInt() shr 8 and 0x1F or flag.toInt()).toByte())
buffer.put(fragmentOffset.toByte())
buffer.put(timeToLive)
buffer.put(protocol)
buffer.putShort(headerChecksum.toShort())
buffer.putInt(sourceIP)
buffer.putInt(destinationIP)
return buffer.array()
}
}
object IPPacketFactory {
private const val IP4_HEADER_SIZE = 20
private const val IP4_VERSION = 0x04
/**
* 从给定的ByteBuffer流创建IPv4标头
*/
fun createIP4Header(buffer: ByteBuffer): IP4Header? {
if (buffer.remaining() < IP4_HEADER_SIZE) {
throw IllegalArgumentException("IP header byte array must have at least $IP4_HEADER_SIZE bytes")
}
val versionAndHeaderLength: Byte = buffer.get()
val ipVersion = (versionAndHeaderLength.toInt() shr 4).toByte()
if (ipVersion.toInt() != IP4_VERSION) {
// throw IllegalArgumentException("Invalid IP version $ipVersion")
return null
}
val internetHeaderLength = (versionAndHeaderLength.toInt() and 0x0F).toByte()
val typeOfService = buffer.get().toInt()
val diffTypeOfService: Byte = (typeOfService shr 2).toByte();
val ecn: Byte = (typeOfService and 0x03).toByte()
val totalLength: Int = buffer.getShort().toInt()
val identification: Int = buffer.getShort().toInt()
val flagsAndFragmentOffset: Short = buffer.getShort()
val mayFragment = flagsAndFragmentOffset.toInt() and 0x4000 != 0
val lastFragment = flagsAndFragmentOffset.toInt() and 0x2000 != 0
val fragmentOffset = (flagsAndFragmentOffset.toInt() and 0x1FFF).toShort()
val timeToLive: Byte = buffer.get()
val protocol: Byte = buffer.get()
val checksum: Int = buffer.getShort().toInt()
val sourceIp: Int = buffer.getInt()
val desIp: Int = buffer.getInt()
if (internetHeaderLength > 5) {
// drop the IP option
for (i in 0 until (internetHeaderLength - 5)) {
buffer.getInt()
}
}
return IP4Header(
ipVersion, internetHeaderLength, diffTypeOfService, ecn, totalLength, identification,
mayFragment, lastFragment, fragmentOffset, timeToLive, protocol, checksum,
sourceIp, desIp
)
}
}

View File

@@ -0,0 +1,212 @@
package com.network.proxy.vpn.transport.protocol
import java.nio.ByteBuffer
import java.nio.ByteOrder
/**
* TCP报头的数据结构。
*/
class TCPHeader(
private var sourcePort: Int = 0, //源端口号 16bit
private var destinationPort: Int = 0, //目的端口号 16bit
var sequenceNumber: Long = 0, //序列号 32bit
var ackNumber: Long = 0, //确认号 32bit
var dataOffset: Int = 0, //数据偏移4bit
var isNS: Boolean = false, //ECN-nonce concealment protection (experimental: see RFC 3540)
var flags: Int = 0, //标志位 9bit
var windowSize: Int = 0, //窗口大小 16bit
var checksum: Int = 0, //校验和 16bit
private var urgentPointer: Int = 0, //紧急指针 16bit
var options: ByteArray? = null //选项
) : TransportHeader {
//options
var maxSegmentSize: Short = 0
private var windowScale: Byte = 0
private var isSelectiveAckPermitted = false
var timeStampSender = 0
var timeStampReplyTo = 0
companion object {
private const val END_OF_OPTIONS_LIST: Byte = 0
private const val NO_OPERATION: Byte = 1
private const val MAX_SEGMENT_SIZE: Byte = 2
private const val WINDOW_SCALE: Byte = 3
private const val SELECTIVE_ACK_PERMITTED: Byte = 4
private const val TIME_STAMP: Byte = 8
}
fun isSYN(): Boolean {
return flags and 0x02 != 0
}
fun isFIN(): Boolean {
return flags and 0x01 != 0
}
fun isRST(): Boolean {
return flags and 0x04 != 0
}
fun isPSH(): Boolean {
return flags and 0x08 != 0
}
fun isACK(): Boolean {
return flags and 0x10 != 0
}
fun isURG(): Boolean {
return flags and 0x20 != 0
}
fun isECE(): Boolean {
return flags and 0x40 != 0
}
fun isCWR(): Boolean {
return flags and 0x80 != 0
}
fun setIsRST(isRST: Boolean) {
flags = if (isRST) {
(flags or 0x04)
} else {
(flags and 0xFB)
}
}
fun setIsSYN(isSYN: Boolean) {
flags = if (isSYN) {
(flags or 0x02)
} else {
(flags and 0xFD)
}
}
fun setIsFIN(isFIN: Boolean) {
flags = if (isFIN) {
(flags or 0x01)
} else {
(flags and 0xFE)
}
}
fun setIsPSH(isPSH: Boolean) {
flags = if (isPSH) {
(flags or 0x08)
} else {
(flags and 0xF7)
}
}
fun setIsACK(isACK: Boolean) {
flags = if (isACK) {
(flags or 0x10)
} else {
(flags and 0xEF)
}
}
fun getTCPHeaderLength(): Int {
return dataOffset * 4
}
fun toBytes(): ByteArray {
val tcpHeaderLength = getTCPHeaderLength()
val tcpHeader = ByteArray(tcpHeaderLength)
val byteBuffer = ByteBuffer.wrap(tcpHeader)
byteBuffer.order(ByteOrder.BIG_ENDIAN)
byteBuffer.putShort(sourcePort.toShort())
byteBuffer.putShort(destinationPort.toShort())
byteBuffer.putInt(sequenceNumber.toInt())
byteBuffer.putInt(ackNumber.toInt())
//is ns and data offset
byteBuffer.put(((dataOffset shl 4) and 0xF0 or (if (isNS) 0x1 else 0x0)).toByte())
byteBuffer.put(flags.toByte())
byteBuffer.putShort(windowSize.toShort())
byteBuffer.putShort(checksum.toShort())
byteBuffer.putShort(urgentPointer.toShort())
// encodeTcpOptions()?.let {
// byteBuffer.put(it)
// }
return tcpHeader
}
fun copy(): TCPHeader {
return TCPHeader(
sourcePort, destinationPort, sequenceNumber, ackNumber,
dataOffset, isNS, flags, windowSize, checksum, urgentPointer,
options
)
}
private fun handleTcpOptions() {
if (options == null) {
return
}
var index = 0
val packet = ByteBuffer.wrap(options!!)
val optionsSize = options!!.size
while (index < optionsSize) {
val optionKind = packet.get()
index++
if (optionKind == END_OF_OPTIONS_LIST || optionKind == NO_OPERATION) {
continue
}
val size = packet.get()
index++
when (optionKind) {
MAX_SEGMENT_SIZE -> {
maxSegmentSize = packet.getShort()
index += 2
}
WINDOW_SCALE -> {
windowScale = packet.get()
index++
}
SELECTIVE_ACK_PERMITTED -> isSelectiveAckPermitted = true
TIME_STAMP -> {
timeStampSender = packet.getInt()
timeStampReplyTo = packet.getInt()
index += 8
}
else -> {
skipRemainingOptions(packet, size.toInt())
index = index + size - 2
}
}
}
}
private fun skipRemainingOptions(packet: ByteBuffer, size: Int) {
for (i in 2 until size) {
packet.get()
}
}
override fun getSourcePort(): Int {
return sourcePort
}
override fun getDestinationPort(): Int {
return destinationPort
}
fun setSourcePort(sourcePort: Int) {
this.sourcePort = sourcePort
}
fun setDestinationPort(destinationPort: Int) {
this.destinationPort = destinationPort
}
}

View File

@@ -0,0 +1,285 @@
package com.network.proxy.vpn.transport.protocol
import com.network.proxy.vpn.transport.Packet
import com.network.proxy.vpn.util.PacketUtil
import java.nio.ByteBuffer
import java.util.concurrent.ThreadLocalRandom
object TCPPacketFactory {
private const val TCP_HEADER_LENGTH = 20
/**
* 从tcp报文创建tcpHeader
*/
@JvmStatic
fun createTCPHeader(byteBuffer: ByteBuffer): TCPHeader {
if (byteBuffer.remaining() < TCP_HEADER_LENGTH) {
throw IllegalArgumentException("Invalid TCP Header Length")
}
val sourcePort: Int = byteBuffer.getShort().toInt() and 0xFFFF
val destinationPort: Int = byteBuffer.getShort().toInt() and 0xFFFF
val sequenceNumber: Long = byteBuffer.getInt().toLong()
val ackNumber: Long = byteBuffer.getInt().toLong()
val dataOffsetAndReserved = byteBuffer.get()
val dataOffset = (dataOffsetAndReserved.toInt() and 0xF0) shr 4
val isNs: Boolean = dataOffsetAndReserved.toInt() and 0x1 > 0x0
val flags = byteBuffer.get().toInt()
val window = byteBuffer.short.toInt()
val checksum = byteBuffer.short.toInt()
val urgentPointer = byteBuffer.short.toInt()
var optionsAndPadding: ByteArray? = null
val optionsSize = dataOffset - 5
if (optionsSize > 0) {
optionsAndPadding = ByteArray(optionsSize * 4)
byteBuffer.get(optionsAndPadding, 0, optionsSize * 4)
}
return TCPHeader(
sourcePort, destinationPort, sequenceNumber, ackNumber,
dataOffset, isNs, flags, window, checksum, urgentPointer, optionsAndPadding
)
}
/**
* 创建带有RST标志的数据包以便在需要重置时发送到客户端。
*/
fun createRstData(ipHeader: IP4Header, tcpHeader: TCPHeader, dataLength: Int): ByteArray {
val ip = ipHeader.copy()
val tcp = tcpHeader.copy()
var ackNumber: Long = 0
var seqNumber: Long = 0
if (tcp.ackNumber > 0) {
seqNumber = tcp.ackNumber
} else {
ackNumber = tcp.sequenceNumber + dataLength
}
tcp.ackNumber = ackNumber
tcp.sequenceNumber = seqNumber
//将IP从源翻转到目标
flipIp(ip, tcp)
ip.identification = 0
tcp.flags = 0
tcp.isNS = false
tcp.setIsRST(true)
tcp.options = null
tcp.windowSize = 0
//重新计算IP长度
val totalLength = ip.getIPHeaderLength() + tcp.getTCPHeaderLength()
ip.totalLength = totalLength
return createPacketData(ip, tcp, null)
}
/**
* 创建数据包数据以发送回客户端
*/
@JvmStatic
fun createResponsePacketData(
ipHeader: IP4Header, tcpHeader: TCPHeader, packetData: ByteArray?, isPsh: Boolean,
ackNumber: Long, seqNumber: Long, timeSender: Int, timeReplyTo: Int
): ByteArray {
val ip = ipHeader.copy()
val tcp = tcpHeader.copy()
flipIp(ip, tcp)
tcp.ackNumber = ackNumber
tcp.sequenceNumber = seqNumber
ip.identification = PacketUtil.getPacketId()
//总是发送ACK
//ACK is always sent
tcp.setIsACK(true)
tcp.setIsSYN(false)
tcp.setIsPSH(isPsh)
tcp.setIsFIN(false)
tcp.timeStampSender = timeSender
tcp.timeStampReplyTo = timeReplyTo
var totalLength = ip.getIPHeaderLength() + tcp.getTCPHeaderLength()
if (packetData != null) {
totalLength += packetData.size
}
ip.totalLength = totalLength
return createPacketData(ip, tcp, packetData)
}
/**
* 向客户端确认服务器已收到请求。
*/
@JvmStatic
fun createResponseAckData(
ipHeader: IP4Header, tcpHeader: TCPHeader, ackToClient: Long
): ByteArray {
val ip = ipHeader.copy()
val tcp = tcpHeader.copy()
flipIp(ip, tcp)
val seqNumber = tcp.ackNumber
tcp.ackNumber = ackToClient
tcp.sequenceNumber = seqNumber
ip.identification = PacketUtil.getPacketId()
//ACK
tcp.setIsACK(true)
tcp.setIsSYN(false)
tcp.setIsPSH(false)
tcp.setIsFIN(false)
ip.totalLength = ip.getIPHeaderLength() + tcp.getTCPHeaderLength()
return createPacketData(ip, tcp, null)
}
//将IP从源翻转到目标
private fun flipIp(ip: IP4Header, tcp: TCPHeader) {
val sourceIp = ip.destinationIP
val destIp = ip.sourceIP
val sourcePort = tcp.getDestinationPort()
val destPort = tcp.getSourcePort()
ip.destinationIP = destIp
ip.sourceIP = sourceIp
tcp.setDestinationPort(destPort)
tcp.setSourcePort(sourcePort)
}
/**
* 通过写回客户端流创建SYN-ACK数据包数据
*/
fun createSynAckPacketData(ipHeader: IP4Header, tcpHeader: TCPHeader): Packet {
val ip = ipHeader.copy()
val tcp = tcpHeader.copy()
flipIp(ip, tcp)
//ack = received sequence + 1
val ackNumber = tcpHeader.sequenceNumber + 1
tcp.ackNumber = ackNumber
//服务器生成的初始序列号
val seqNumber = ThreadLocalRandom.current().nextLong(0, 100000)
tcp.sequenceNumber = seqNumber
//SYN-ACK
tcp.setIsACK(true)
tcp.setIsSYN(true)
tcp.timeStampReplyTo = tcp.timeStampSender
tcp.timeStampSender = PacketUtil.currentTime
return Packet(ip, tcp, createPacketData(ip, tcp, null))
}
/**
* 创建发送到客户端的FIN-ACK
*/
fun createFinAckData(
ipHeader: IP4Header, tcpHeader: TCPHeader, ackToClient: Long,
seqToClient: Long, isFin: Boolean, isAck: Boolean
): ByteArray {
val ip = ipHeader.copy()
val tcp = tcpHeader.copy()
flipIp(ip, tcp)
tcp.ackNumber = ackToClient
tcp.sequenceNumber = seqToClient
ip.identification = PacketUtil.getPacketId()
//ACK
tcp.setIsACK(isAck)
tcp.setIsSYN(false)
tcp.setIsPSH(false)
tcp.setIsFIN(isFin)
ip.totalLength = ip.getIPHeaderLength() + tcp.getTCPHeaderLength()
return createPacketData(ip, tcp, null)
}
fun createFinData(
ip: IP4Header, tcp: TCPHeader, ackNumber: Long, seqNumber: Long,
timeSender: Int, timeReplyTo: Int
): ByteArray {
//将IP从源翻转到目标
flipIp(ip, tcp)
tcp.ackNumber = ackNumber
tcp.sequenceNumber = seqNumber
ip.identification = PacketUtil.getPacketId()
tcp.timeStampReplyTo = timeReplyTo
tcp.timeStampSender = timeSender
tcp.flags = 0
tcp.isNS = false
tcp.setIsACK(true)
tcp.setIsFIN(true)
tcp.options = null
//窗口大小应为零
tcp.windowSize = 0
ip.totalLength = ip.getIPHeaderLength() + tcp.getTCPHeaderLength()
return createPacketData(ip, tcp, null)
}
/**
* 从tcpHeader创建tcp报文
*/
private fun createPacketData(ipHeader: IP4Header, tcpHeader: TCPHeader, data: ByteArray?):
ByteArray {
val dataLength = data?.size ?: 0
val buffer =
ByteBuffer.allocate(ipHeader.getIPHeaderLength() + tcpHeader.getTCPHeaderLength() + dataLength)
val ipBuffer = ipHeader.toBytes()
val tcpBuffer = tcpHeader.toBytes()
buffer.put(ipBuffer)
buffer.put(tcpBuffer)
data?.let { buffer.put(it) }
val zero = byteArrayOf(0, 0)
//计算前先将校验和清零
buffer.position(10)
buffer.put(zero)
val ipChecksum = PacketUtil.calculateChecksum(buffer.array(), 0, ipBuffer.size)
buffer.position(10)
buffer.put(ipChecksum)
val tcpStart = ipBuffer.size
buffer.position(tcpStart + 16)
buffer.put(zero)
val tcpChecksum = PacketUtil.calculateTCPHeaderChecksum(
buffer.array(), tcpStart, tcpBuffer.size + dataLength,
ipHeader.destinationIP, ipHeader.sourceIP
)
//将新的校验和写回阵列
buffer.position(tcpStart + 16)
buffer.put(tcpChecksum)
return buffer.array()
}
}

View File

@@ -0,0 +1,6 @@
package com.network.proxy.vpn.transport.protocol
interface TransportHeader {
fun getSourcePort(): Int
fun getDestinationPort(): Int
}

View File

@@ -0,0 +1,89 @@
package com.network.proxy.vpn.transport.protocol
import com.network.proxy.vpn.util.PacketUtil
import java.nio.ByteBuffer
/**
* UDP报头的数据结构。
*/
data class UDPHeader(
var sourcePort: Int = 0, //源端口号 16bit
var destinationPort: Int = 0, //目的端口号 16bit
var length: Int = 0, //UDP数据报长度 16bit
var checksum: Int = 0 //校验和 16bit
)
object UDPPacketFactory {
@JvmStatic
fun createUDPHeader(stream: ByteBuffer): UDPHeader {
require(stream.remaining() >= 8) { "Minimum UDP header is 8 bytes." }
val srcPort = stream.getShort().toInt() and 0xffff
val destPort = stream.getShort().toInt() and 0xffff
val length = stream.getShort().toInt() and 0xffff
val checksum = stream.getShort().toInt()
return UDPHeader(srcPort, destPort, length, checksum)
}
/**
* 创建用于响应vpn客户端的数据包
*/
@JvmStatic
fun createResponsePacket(ip: IP4Header, udp: UDPHeader, packetData: ByteArray?): ByteArray {
val buffer: ByteArray
var udpLen = 8
if (packetData != null) {
udpLen += packetData.size
}
val srcPort = udp.destinationPort
val destPort = udp.sourcePort
val checksum: Short = 0
val ipHeader = ip.copy()
val srcIp = ip.destinationIP
val destIp = ip.sourceIP
ipHeader.setMayFragment(false)
ipHeader.sourceIP = srcIp
ipHeader.destinationIP = destIp
ipHeader.identification = PacketUtil.getPacketId()
//ip的长度是整个数据包的长度 => IP header length + UDP header length (8) + UDP body length
val totalLength = ipHeader.getIPHeaderLength() + udpLen
ipHeader.totalLength = totalLength
buffer = ByteArray(totalLength)
val ipData = ipHeader.toBytes()
// clear IP checksum
ipData[11] = 0
ipData[10] = 0
//calculate checksum for IP header
val ipChecksum = PacketUtil.calculateChecksum(ipData, 0, ipData.size)
//write result of checksum back to buffer
System.arraycopy(ipChecksum, 0, ipData, 10, 2)
System.arraycopy(ipData, 0, buffer, 0, ipData.size)
//copy UDP header to buffer
var start = ipData.size
val intContainer = ByteArray(4)
PacketUtil.writeIntToBytes(srcPort, intContainer, 0)
//extract the last two bytes of int value
System.arraycopy(intContainer, 2, buffer, start, 2)
start += 2
PacketUtil.writeIntToBytes(destPort, intContainer, 0)
System.arraycopy(intContainer, 2, buffer, start, 2)
start += 2
PacketUtil.writeIntToBytes(udpLen, intContainer, 0)
System.arraycopy(intContainer, 2, buffer, start, 2)
start += 2
PacketUtil.writeIntToBytes(checksum.toInt(), intContainer, 0)
System.arraycopy(intContainer, 2, buffer, start, 2)
start += 2
//now copy udp data
if (packetData != null) System.arraycopy(packetData, 0, buffer, start, packetData.size)
return buffer
}
}

View File

@@ -0,0 +1,265 @@
package com.network.proxy.vpn.util
import android.util.Log
import com.network.proxy.vpn.formatTag
import com.network.proxy.vpn.transport.protocol.IP4Header
import com.network.proxy.vpn.transport.protocol.TCPHeader
import java.nio.ByteBuffer
import java.nio.ByteOrder
/**
* Helper class to perform various useful task
*
* @author Borey Sao
* Date: May 8, 2014
*/
object PacketUtil {
@get:Synchronized
private var packetId = 0
fun getPacketId() = packetId++
val currentTime: Int
get() = (System.currentTimeMillis() / 1000).toInt()
/**
* convert int to byte array
* [...](https://docs.oracle.com/javase/tutorial/java/nutsandbolts/datatypes.html)
*
* @param value int value 32 bits
* @param buffer array of byte to write to
* @param offset position to write to
*/
fun writeIntToBytes(value: Int, buffer: ByteArray, offset: Int) {
if (buffer.size - offset < 4) {
return
}
buffer[offset] = (value ushr 24 and 0x000000FF).toByte()
buffer[offset + 1] = (value shr 16 and 0x000000FF).toByte()
buffer[offset + 2] = (value shr 8 and 0x000000FF).toByte()
buffer[offset + 3] = (value and 0x000000FF).toByte()
}
/**
* convert array of max 4 bytes to int
*
* @param buffer byte array
* @param start Starting point to be read in byte array
* @param length Length to be read
* @return value of int
*/
fun getNetworkInt(buffer: ByteArray, start: Int, length: Int): Int {
var value = 0
var end = start + Math.min(length, 4)
if (end > buffer.size) end = buffer.size
for (i in start until end) {
value = value or (buffer[i].toInt() and 0xFF)
if (i < end - 1) value = value shl 8
}
return value
}
/**
* validate TCP header checksum
*
* @param source Source Port
* @param destination Destination Port
* @param data Payload
* @param tcpLength TCP Header length
* @return boolean
*/
fun isValidTCPChecksum(
source: Int, destination: Int,
data: ByteArray, tcpLength: Short, tcpOffset: Int
): Boolean {
var buffersize = tcpLength + 12
var isodd = false
if (buffersize % 2 != 0) {
buffersize++
isodd = true
}
val buffer = ByteBuffer.allocate(buffersize)
buffer.putInt(source)
buffer.putInt(destination)
buffer.put(0.toByte()) //reserved => 0
buffer.put(6.toByte()) //TCP protocol => 6
buffer.putShort(tcpLength)
buffer.put(data, tcpOffset, tcpLength.toInt())
if (isodd) {
buffer.put(0.toByte())
}
return isValidIPChecksum(buffer.array(), buffersize)
}
/**
* validate IP Header checksum
*
* @param data byte stream
* @return boolean
*/
private fun isValidIPChecksum(data: ByteArray, length: Int): Boolean {
var start = 0
var sum = 0
while (start < length) {
sum += getNetworkInt(data, start, 2)
start = start + 2
}
//carry over one's complement
while (sum shr 16 > 0) sum = (sum and 0xffff) + (sum shr 16)
//flip the bit to get one' complement
sum = sum.inv()
val buffer = ByteBuffer.allocate(4)
buffer.putInt(sum)
return buffer.getShort(2).toInt() == 0
}
fun calculateChecksum(data: ByteArray, offset: Int, length: Int): ByteArray {
var start = offset
var sum = 0
while (start < length) {
sum += getNetworkInt(data, start, 2)
start = start + 2
}
//carry over one's complement
while (sum shr 16 > 0) {
sum = (sum and 0xffff) + (sum shr 16)
}
//flip the bit to get one' complement
sum = sum.inv()
//extract the last two byte of int
val checksum = ByteArray(2)
checksum[0] = (sum shr 8).toByte()
checksum[1] = sum.toByte()
return checksum
}
fun calculateTCPHeaderChecksum(
data: ByteArray,
offset: Int,
tcplength: Int,
destip: Int,
sourceip: Int
): ByteArray {
var buffersize = tcplength + 12
var odd = false
if (buffersize % 2 != 0) {
buffersize++
odd = true
}
val buffer = ByteBuffer.allocate(buffersize)
buffer.order(ByteOrder.BIG_ENDIAN)
//create virtual header
buffer.putInt(sourceip)
buffer.putInt(destip)
buffer.put(0.toByte()) //reserved => 0
buffer.put(6.toByte()) //tcp protocol => 6
buffer.putShort(tcplength.toShort())
//add actual header + data
buffer.put(data, offset, tcplength)
//padding last byte to zero
if (odd) {
buffer.put(0.toByte())
}
val tcparray = buffer.array()
return calculateChecksum(tcparray, 0, buffersize)
}
fun intToIPAddress(addressInt: Int): String {
return (addressInt ushr 24 and 0x000000FF).toString() + "." +
(addressInt ushr 16 and 0x000000FF) + "." +
(addressInt ushr 8 and 0x000000FF) + "." +
(addressInt and 0x000000FF)
}
fun getOutput(
ipHeader: IP4Header, tcpheader: TCPHeader,
packetData: ByteArray
): String {
val tcpLength = (packetData.size -
ipHeader.getIPHeaderLength()).toShort()
val isValidChecksum = isValidTCPChecksum(
ipHeader.sourceIP, ipHeader.destinationIP,
packetData, tcpLength, ipHeader.getIPHeaderLength()
)
val isValidIPChecksum = isValidIPChecksum(
packetData,
ipHeader.getIPHeaderLength()
)
val packetBodyLength = (packetData.size - ipHeader.getIPHeaderLength()
- tcpheader.getTCPHeaderLength())
val str = StringBuilder("\r\nIP Version: ")
.append(ipHeader.ipVersion.toInt())
.append("\r\nProtocol: ").append(ipHeader.protocol.toInt())
.append("\r\nID# ").append(ipHeader.identification)
.append("\r\nTotal Length: ").append(ipHeader.totalLength)
.append("\r\nData Length: ").append(packetBodyLength)
.append("\r\nDest: ").append(intToIPAddress(ipHeader.destinationIP))
.append(":").append(tcpheader.getDestinationPort())
.append("\r\nSrc: ").append(intToIPAddress(ipHeader.sourceIP))
.append(":").append(tcpheader.getSourcePort())
.append("\r\nACK: ").append(tcpheader.ackNumber)
.append("\r\nSeq: ").append(tcpheader.sequenceNumber)
.append("\r\nIP Header length: ").append(ipHeader.getIPHeaderLength())
.append("\r\nTCP Header length: ").append(tcpheader.getTCPHeaderLength())
.append("\r\nACK: ").append(tcpheader.isACK())
.append("\r\nSYN: ").append(tcpheader.isSYN())
.append("\r\nCWR: ").append(tcpheader.isCWR())
.append("\r\nECE: ").append(tcpheader.isECE())
.append("\r\nFIN: ").append(tcpheader.isFIN())
.append("\r\nNS: ").append(tcpheader.isNS)
.append("\r\nPSH: ").append(tcpheader.isPSH())
.append("\r\nRST: ").append(tcpheader.isRST())
.append("\r\nURG: ").append(tcpheader.isURG())
.append("\r\nIP checksum: ").append(ipHeader.headerChecksum)
.append("\r\nIs Valid IP Checksum: ").append(isValidIPChecksum)
.append("\r\nTCP Checksum: ").append(tcpheader.checksum)
.append("\r\nIs Valid TCP checksum: ").append(isValidChecksum)
.append("\r\nFragment Offset: ").append(ipHeader.fragmentOffset.toInt())
.append("\r\nWindow: ").append(tcpheader.windowSize)
.append("\r\nData Offset: ").append(tcpheader.dataOffset)
return str.toString()
}
/**
* detect packet corruption flag in tcp options sent from client ACK
*
* @param tcpHeader TCPHeader
* @return boolean
*/
fun isPacketCorrupted(tcpHeader: TCPHeader): Boolean {
val options = tcpHeader.options
if (options != null) {
var i = 0
while (i < options.size) {
val kind = options[i]
if (kind.toInt() == 0 || kind.toInt() == 1) {
} else if (kind.toInt() == 2) {
i += 3
} else if (kind.toInt() == 3 || kind.toInt() == 14) {
i += 2
} else if (kind.toInt() == 4) {
i++
} else if (kind.toInt() == 5 || kind.toInt() == 15) {
i = i + options[++i] - 2
} else if (kind.toInt() == 8) {
i += 9
} else if (kind.toInt() == 23) {
return true
} else {
Log.e(
formatTag(PacketUtil::class.java.name),
"unknown option: $kind"
)
}
i++
}
}
return false
}
}

View File

@@ -4,84 +4,82 @@ import java.nio.ByteBuffer
import kotlin.math.min
class TLS {
object TLS {
companion object {
/**
* 判断是否是TLS Client Hello
*/
fun isTLSClientHello(packetData: ByteBuffer): Boolean {
if (packetData.remaining() < 43) return false
val position = packetData.position()
val data = packetData.array()
if (data[position].toInt() != 0x16 /* handshake */) return false
if (data[1 + position].toInt() != 0x03) return false
return if (data[5 + position].toInt() != 0x01) false else data[9 + position].toInt() == 0x03 && data[10 + position] >= 0x00 && data[1 + position] <= 0x03
}
/**
* 判断是否是TLS Client Hello
*/
fun isTLSClientHello(packetData: ByteBuffer): Boolean {
if (packetData.remaining() < 43) return false
val position = packetData.position()
val data = packetData.array()
if (data[position].toInt() != 0x16 /* handshake */) return false
if (data[1 + position].toInt() != 0x03) return false
return if (data[5 + position].toInt() != 0x01) false else data[9 + position].toInt() == 0x03 && data[10 + position] >= 0x00 && data[1 + position] <= 0x03
}
/**
* 从TLS Client Hello 解析域名
*/
fun getDomain(buffer: ByteBuffer): String? {
var offset = buffer.position()
val limit = buffer.limit()
//TLS Client Hello
if (buffer[offset].toInt() != 0x16) return null
//Skip 43 byte header
offset += 43
if (limit < (offset + 1)) return null
/**
* 从TLS Client Hello 解析域名
*/
fun getDomain(buffer: ByteBuffer): String? {
var offset = buffer.position()
val limit = buffer.limit()
//TLS Client Hello
if (buffer[offset].toInt() != 0x16) return null
//Skip 43 byte header
offset += 43
if (limit < (offset + 1)) return null
//read session id
val sessionIDLength = buffer[offset++]
offset += sessionIDLength
//read session id
val sessionIDLength = buffer[offset++]
offset += sessionIDLength
//read cipher suites
if (offset + 2 > limit) return null
val cipherSuitesLength = buffer.getShort(offset)
offset += 2
offset += cipherSuitesLength
//read cipher suites
if (offset + 2 > limit) return null
val cipherSuitesLength = buffer.getShort(offset)
offset += 2
offset += cipherSuitesLength
//read Compression method.
if (offset + 1 > limit) return null
val compressionMethodLength = buffer[offset++].toInt() and 0xFF
offset += compressionMethodLength
if (offset > limit) return null
//read Compression method.
if (offset + 1 > limit) return null
val compressionMethodLength = buffer[offset++].toInt() and 0xFF
offset += compressionMethodLength
if (offset > limit) return null
//read Extensions
if (offset + 2 > limit) return null
//read Extensions
if (offset + 2 > limit) return null
val extensionsLength = buffer.getShort(offset)
offset += 2
if (offset + extensionsLength > limit) return null
val extensionsLength = buffer.getShort(offset)
offset += 2
if (offset + extensionsLength > limit) return null
var end: Int = offset + extensionsLength
end = min(end, limit)
while (offset + 4 <= end) {
val extensionType = buffer.getShort(offset)
val extensionLength = buffer.getShort(offset + 2)
offset += 4
//server_name
if (extensionType.toInt() == 0) {
if (offset + 5 > limit) return null
val serverNameListLength = buffer.getShort(offset)
offset += 2
if (offset > limit) return null
if (offset + serverNameListLength > limit) return null
var end: Int = offset + extensionsLength
end = min(end, limit)
while (offset + 4 <= end) {
val extensionType = buffer.getShort(offset)
val extensionLength = buffer.getShort(offset + 2)
offset += 4
//server_name
if (extensionType.toInt() == 0) {
if (offset + 5 > limit) return null
val serverNameListLength = buffer.getShort(offset)
offset += 2
if (offset > limit) return null
if (offset + serverNameListLength > limit) return null
val serverNameType = buffer[offset++]
val serverNameLength = buffer.getShort(offset)
offset += 2
if (offset > limit || serverNameType.toInt() != 0) return null
if (offset + serverNameLength > limit) return null
val serverNameBytes = ByteArray(serverNameLength.toInt())
buffer.get(serverNameBytes)
return String(serverNameBytes)
} else {
offset += extensionLength
}
val serverNameType = buffer[offset++]
val serverNameLength = buffer.getShort(offset)
offset += 2
if (offset > limit || serverNameType.toInt() != 0) return null
if (offset + serverNameLength > limit) return null
val serverNameBytes = ByteArray(serverNameLength.toInt())
buffer.get(serverNameBytes)
return String(serverNameBytes)
} else {
offset += extensionLength
}
return null
}
return null
}
}

View File

@@ -133,6 +133,7 @@ class Server extends Network {
void ssl(ChannelContext channelContext, Channel channel, Uint8List data) async {
var hostAndPort = channelContext.host;
try {
hostAndPort?.scheme = HostAndPort.httpsScheme;
if (hostAndPort == null && TLS.getDomain(data) != null) {
hostAndPort = HostAndPort.host(TLS.getDomain(data)!, 443);
}
@@ -140,7 +141,7 @@ class Server extends Network {
Channel? remoteChannel = channelContext.serverChannel;
if (HostFilter.filter(hostAndPort?.host)) {
if (HostFilter.filter(hostAndPort?.host) || !configuration.enableSsl) {
remoteChannel = remoteChannel ?? await channelContext.connectServerChannel(hostAndPort!, RelayHandler(channel));
relay(channel, remoteChannel);
channel.pipeline.channelRead(channelContext, channel, data);