diff --git a/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/Model/ElevationStatusResponse.cs b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/Model/ElevationStatusResponse.cs index 44828ded..d589f69d 100644 --- a/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/Model/ElevationStatusResponse.cs +++ b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/Model/ElevationStatusResponse.cs @@ -5,5 +5,10 @@ namespace Snap.Hutao.Core.LifeCycle.InterProcess.Model; internal sealed class ElevationStatusResponse { + public ElevationStatusResponse(bool isElevated) + { + IsElevated = isElevated; + } + public bool IsElevated { get; set; } } diff --git a/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PipeStreamExtension.cs b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PipeStreamExtension.cs index 7a3deadc..81fd8263 100644 --- a/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PipeStreamExtension.cs +++ b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PipeStreamExtension.cs @@ -2,47 +2,83 @@ // Licensed under the MIT license. using Snap.Hutao.Core.ExceptionService; +using System.Buffers; using System.IO.Hashing; using System.IO.Pipes; +using System.Runtime.CompilerServices; namespace Snap.Hutao.Core.LifeCycle.InterProcess; internal static class PipeStreamExtension { - public static unsafe byte[] GetValidatedContent(this PipeStream stream, PipePacketHeader* header) + public static TData? ReadJsonContent(this PipeStream stream, ref readonly PipePacketHeader header) { - byte[] content = new byte[header->ContentLength]; - stream.ReadAtLeast(content, header->ContentLength, false); - HutaoException.ThrowIf(XxHash64.HashToUInt64(content) != header->Checksum, "PipePacket Content Hash incorrect"); - return content; + using (IMemoryOwner memoryOwner = MemoryPool.Shared.Rent(header.ContentLength)) + { + Span content = memoryOwner.Memory.Span[..header.ContentLength]; + stream.ReadExactly(content); + + HutaoException.ThrowIf(XxHash64.HashToUInt64(content) != header.Checksum, "PipePacket Content Hash incorrect"); + return JsonSerializer.Deserialize(content); + } } - public static unsafe PipePacketHeader ReadPacket(this PipeStream stream, out TData? data) + public static unsafe void ReadPacket(this PipeStream stream, out PipePacketHeader header, out TData? data) where TData : class { data = default; - Span headerSpan = stackalloc byte[sizeof(PipePacketHeader)]; - stream.ReadExactly(headerSpan); - fixed (byte* pHeader = headerSpan) + stream.ReadPacket(out header); + if (header.ContentType is PipePacketContentType.Json) { - PipePacketHeader* header = (PipePacketHeader*)pHeader; - if (header->ContentType is PipePacketContentType.Json) - { - ReadOnlySpan content = stream.GetValidatedContent(header); - data = JsonSerializer.Deserialize(content); - } - - return *header; + data = stream.ReadJsonContent(in header); } } - public static unsafe void WritePacket(this PipeStream stream, PipePacketHeader* header, byte[] content) + [SkipLocalsInit] + public static unsafe void ReadPacket(this PipeStream stream, out PipePacketHeader header) { - header->ContentLength = content.Length; - header->Checksum = XxHash64.HashToUInt64(content); + fixed (PipePacketHeader* pHeader = &header) + { + stream.ReadExactly(new(pHeader, sizeof(PipePacketHeader))); + } + } - stream.Write(new(header, sizeof(PipePacketHeader))); + public static unsafe void WritePacketWithJsonContent(this PipeStream stream, byte version, PipePacketType type, PipePacketCommand command, TData data) + { + PipePacketHeader header = default; + header.Version = version; + header.Type = type; + header.Command = command; + header.ContentType = PipePacketContentType.Json; + + stream.WritePacket(ref header, JsonSerializer.SerializeToUtf8Bytes(data)); + } + + public static unsafe void WritePacket(this PipeStream stream, ref PipePacketHeader header, byte[] content) + { + header.ContentLength = content.Length; + header.Checksum = XxHash64.HashToUInt64(content); + + stream.WritePacket(in header); stream.Write(content); } + + public static unsafe void WritePacket(this PipeStream stream, byte version, PipePacketType type, PipePacketCommand command) + { + PipePacketHeader header = default; + header.Version = version; + header.Type = type; + header.Command = command; + + stream.WritePacket(in header); + } + + public static unsafe void WritePacket(this PipeStream stream, ref readonly PipePacketHeader header) + { + fixed (PipePacketHeader* pHeader = &header) + { + stream.Write(new(pHeader, sizeof(PipePacketHeader))); + } + } } diff --git a/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipe.cs b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipe.cs new file mode 100644 index 00000000..5b02f7f5 --- /dev/null +++ b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipe.cs @@ -0,0 +1,10 @@ +// Copyright (c) DGP Studio. All rights reserved. +// Licensed under the MIT license. + +namespace Snap.Hutao.Core.LifeCycle.InterProcess; + +internal static class PrivateNamedPipe +{ + public const int Version = 1; + public const string Name = "Snap.Hutao.PrivateNamedPipe"; +} \ No newline at end of file diff --git a/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeClient.cs b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeClient.cs index 34426b94..5f3eef36 100644 --- a/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeClient.cs +++ b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeClient.cs @@ -11,62 +11,30 @@ namespace Snap.Hutao.Core.LifeCycle.InterProcess; [ConstructorGenerated] internal sealed partial class PrivateNamedPipeClient : IDisposable { - private readonly NamedPipeClientStream clientStream = new(".", "Snap.Hutao.PrivateNamedPipe", PipeDirection.InOut, PipeOptions.Asynchronous | PipeOptions.WriteThrough); + private readonly NamedPipeClientStream clientStream = new(".", PrivateNamedPipe.Name, PipeDirection.InOut, PipeOptions.Asynchronous | PipeOptions.WriteThrough); private readonly RuntimeOptions runtimeOptions; public unsafe bool TryRedirectActivationTo(AppActivationArguments args) { if (clientStream.TryConnectOnce()) { + clientStream.WritePacket(PrivateNamedPipe.Version, PipePacketType.Request, PipePacketCommand.RequestElevationStatus); + clientStream.ReadPacket(stackalloc byte[sizeof(PipePacketHeader)], out ElevationStatusResponse? response); + ArgumentNullException.ThrowIfNull(response); + + // Prefer elevated instance + if (runtimeOptions.IsElevated && !response.IsElevated) { - // Connect - PipePacketHeader connectPacket = default; - connectPacket.Version = 1; - connectPacket.Type = PipePacketType.Request; - connectPacket.Command = PipePacketCommand.RequestElevationStatus; - - clientStream.Write(new(&connectPacket, sizeof(PipePacketHeader))); - } - - clientStream.ReadPacket(out ElevationStatusResponse? serverElevationStatus); - ArgumentNullException.ThrowIfNull(serverElevationStatus); - - if (runtimeOptions.IsElevated && !serverElevationStatus.IsElevated) - { - // Kill previous instance to use current elevated instance - PipePacketHeader killPacket = default; - killPacket.Version = 1; - killPacket.Type = PipePacketType.SessionTermination; - killPacket.Command = PipePacketCommand.Exit; - - clientStream.Write(new(&killPacket, sizeof(PipePacketHeader))); + // Notify previous instance to exit + clientStream.WritePacket(PrivateNamedPipe.Version, PipePacketType.SessionTermination, PipePacketCommand.Exit); clientStream.Flush(); return false; } - { - // Redirect to previous instance - PipePacketHeader redirectActivationPacket = default; - redirectActivationPacket.Version = 1; - redirectActivationPacket.Type = PipePacketType.Request; - redirectActivationPacket.Command = PipePacketCommand.RedirectActivation; - redirectActivationPacket.ContentType = PipePacketContentType.Json; - - HutaoActivationArguments hutaoArgs = HutaoActivationArguments.FromAppActivationArguments(args, isRedirected: true); - byte[] jsonBytes = JsonSerializer.SerializeToUtf8Bytes(hutaoArgs); - - clientStream.WritePacket(&redirectActivationPacket, jsonBytes); - } - - { - // Terminate session - PipePacketHeader terminationPacket = default; - terminationPacket.Version = 1; - terminationPacket.Type = PipePacketType.SessionTermination; - - clientStream.Write(new(&terminationPacket, sizeof(PipePacketHeader))); - } - + // Redirect to previous instance + HutaoActivationArguments hutaoArgs = HutaoActivationArguments.FromAppActivationArguments(args, isRedirected: true); + clientStream.WritePacketWithJsonContent(PrivateNamedPipe.Version, PipePacketType.Request, PipePacketCommand.RedirectActivation, hutaoArgs); + clientStream.WritePacket(PrivateNamedPipe.Version, PipePacketType.SessionTermination, PipePacketCommand.None); clientStream.Flush(); return true; } diff --git a/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeMessageDispatcher.cs b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeMessageDispatcher.cs index 9047a5fd..5b7e9a25 100644 --- a/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeMessageDispatcher.cs +++ b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeMessageDispatcher.cs @@ -19,7 +19,7 @@ internal sealed partial class PrivateNamedPipeMessageDispatcher serviceProvider.GetRequiredService().Activate(args); } - public void Exit() + public void ExitApplication() { ITaskContext taskContext = serviceProvider.GetRequiredService(); App app = serviceProvider.GetRequiredService(); diff --git a/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeServer.cs b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeServer.cs index f39ee8f2..c466cf62 100644 --- a/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeServer.cs +++ b/src/Snap.Hutao/Snap.Hutao/Core/LifeCycle/InterProcess/PrivateNamedPipeServer.cs @@ -35,7 +35,7 @@ internal sealed partial class PrivateNamedPipeServer : IDisposable } serverStream = NamedPipeServerStreamAcl.Create( - "Snap.Hutao.PrivateNamedPipe", + PrivateNamedPipe.Name, PipeDirection.InOut, NamedPipeServerStream.MaxAllowedServerInstances, PipeTransmissionMode.Byte, @@ -77,42 +77,27 @@ internal sealed partial class PrivateNamedPipeServer : IDisposable { while (serverStream.IsConnected && !token.IsCancellationRequested) { - PipePacketHeader header = serverStream.ReadPacket(out HutaoActivationArguments? hutaoArgs); + serverStream.ReadPacket(out PipePacketHeader header); switch ((header.Type, header.Command)) { case (PipePacketType.Request, PipePacketCommand.RequestElevationStatus): - RespondElevationStatus(); + ElevationStatusResponse resp = new(runtimeOptions.IsElevated); + serverStream.WritePacketWithJsonContent(PrivateNamedPipe.Version, PipePacketType.Response, PipePacketCommand.ResponseElevationStatus, resp); + serverStream.Flush(); break; case (PipePacketType.Request, PipePacketCommand.RedirectActivation): + HutaoActivationArguments? hutaoArgs = serverStream.ReadJsonContent(in header); messageDispatcher.RedirectActivation(hutaoArgs); break; case (PipePacketType.SessionTermination, _): serverStream.Disconnect(); if (header.Command is PipePacketCommand.Exit) { - messageDispatcher.Exit(); + messageDispatcher.ExitApplication(); } return; } } - - void RespondElevationStatus() - { - PipePacketHeader elevatedPacket = default; - elevatedPacket.Version = 1; - elevatedPacket.Type = PipePacketType.Response; - elevatedPacket.Command = PipePacketCommand.ResponseElevationStatus; - elevatedPacket.ContentType = PipePacketContentType.Json; - - ElevationStatusResponse resp = new() - { - IsElevated = runtimeOptions.IsElevated, - }; - - byte[] elevatedBytes = JsonSerializer.SerializeToUtf8Bytes(resp); - serverStream.WritePacket(&elevatedPacket, elevatedBytes); - serverStream.Flush(); - } } } \ No newline at end of file