using System;
using System.Threading;
using System.Threading.Tasks;

namespace DNS.Protocol.Utils
{
    /// <summary>
    /// Extension methods that let a caller stop waiting on a task when a
    /// <see cref="CancellationToken"/> fires or a timeout elapses.
    /// </summary>
    /// <remarks>
    /// These helpers only abandon the wait; they do not cancel the wrapped task,
    /// which keeps running after the wrapper has thrown.
    /// </remarks>
    public static class TaskExtensions
    {
        /// <summary>
        /// Awaits <paramref name="task"/>, giving up as soon as <paramref name="token"/> is cancelled.
        /// </summary>
        /// <param name="task">The task to wait for.</param>
        /// <param name="token">Token whose cancellation aborts the wait.</param>
        /// <returns>A task that completes when <paramref name="task"/> completes successfully.</returns>
        /// <exception cref="OperationCanceledException">
        /// <paramref name="token"/> was cancelled before <paramref name="task"/> completed.
        /// </exception>
        public static async Task WithCancellation(this Task task, CancellationToken token)
        {
            // Completed from the cancellation callback. RunContinuationsAsynchronously keeps the
            // awaiting continuation from running inline on the thread that calls Cancel().
            TaskCompletionSource<bool> cancelled =
                new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

            // The registration is disposed as soon as the race is decided so the callback
            // does not stay attached to a long-lived token.
            using (token.Register(state => ((TaskCompletionSource<bool>)state).TrySetResult(true), cancelled))
            {
                Task winner = await Task.WhenAny(task, cancelled.Task).ConfigureAwait(false);

                if (winner != task)
                {
                    throw new OperationCanceledException(token);
                }
            }

            // The task has already finished here; awaiting it propagates its exception, if any.
            await task.ConfigureAwait(false);
        }

        /// <summary>
        /// Awaits <paramref name="task"/> and returns its result, giving up as soon as
        /// <paramref name="token"/> is cancelled.
        /// </summary>
        /// <typeparam name="T">The result type of <paramref name="task"/>.</typeparam>
        /// <param name="task">The task to wait for.</param>
        /// <param name="token">Token whose cancellation aborts the wait.</param>
        /// <returns>The result of <paramref name="task"/>.</returns>
        /// <exception cref="OperationCanceledException">
        /// <paramref name="token"/> was cancelled before <paramref name="task"/> completed.
        /// </exception>
        public static async Task<T> WithCancellation<T>(this Task<T> task, CancellationToken token)
        {
            // The cast selects the non-generic overload, which owns the race logic.
            await WithCancellation((Task)task, token).ConfigureAwait(false);
            return await task.ConfigureAwait(false);
        }

        /// <summary>
        /// Awaits <paramref name="task"/>, giving up when <paramref name="timeout"/> elapses or
        /// <paramref name="cancellationToken"/> is cancelled, whichever happens first.
        /// </summary>
        /// <param name="task">The task to wait for.</param>
        /// <param name="timeout">Maximum time to wait.</param>
        /// <param name="cancellationToken">Optional token that also aborts the wait.</param>
        /// <returns>A task that completes when <paramref name="task"/> completes successfully.</returns>
        /// <exception cref="OperationCanceledException">
        /// The timeout elapsed or <paramref name="cancellationToken"/> was cancelled before
        /// <paramref name="task"/> completed.
        /// </exception>
        public static async Task WithCancellationTimeout(this Task task, TimeSpan timeout, CancellationToken cancellationToken = default(CancellationToken))
        {
            using (CancellationTokenSource timeoutSource = new CancellationTokenSource(timeout))
            using (CancellationTokenSource linkSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutSource.Token, cancellationToken))
            {
                // Await inside the using blocks so the sources outlive the wait.
                await WithCancellation(task, linkSource.Token).ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Awaits <paramref name="task"/> and returns its result, giving up when
        /// <paramref name="timeout"/> elapses or <paramref name="cancellationToken"/> is cancelled,
        /// whichever happens first.
        /// </summary>
        /// <typeparam name="T">The result type of <paramref name="task"/>.</typeparam>
        /// <param name="task">The task to wait for.</param>
        /// <param name="timeout">Maximum time to wait.</param>
        /// <param name="cancellationToken">Optional token that also aborts the wait.</param>
        /// <returns>The result of <paramref name="task"/>.</returns>
        /// <exception cref="OperationCanceledException">
        /// The timeout elapsed or <paramref name="cancellationToken"/> was cancelled before
        /// <paramref name="task"/> completed.
        /// </exception>
        public static async Task<T> WithCancellationTimeout<T>(this Task<T> task, TimeSpan timeout, CancellationToken cancellationToken = default(CancellationToken))
        {
            // The cast selects the non-generic overload, which owns the timeout logic.
            await WithCancellationTimeout((Task)task, timeout, cancellationToken).ConfigureAwait(false);
            return await task.ConfigureAwait(false);
        }
    }
}