From ddf2aef7b75bb1f09fd93648887fd1e68539bf57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B6=B5=E6=9B=A6?= Date: Mon, 5 Jan 2026 20:23:42 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E6=8B=86?= =?UTF-8?q?=E5=88=86=20utils=20=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- xiaomusic/utils.py | 1502 ------------------------------ xiaomusic/utils/__init__.py | 125 +++ xiaomusic/utils/file_utils.py | 189 ++++ xiaomusic/utils/music_utils.py | 695 ++++++++++++++ xiaomusic/utils/network_utils.py | 448 +++++++++ xiaomusic/utils/system_utils.py | 299 ++++++ xiaomusic/utils/text_utils.py | 264 ++++++ 7 files changed, 2020 insertions(+), 1502 deletions(-) delete mode 100644 xiaomusic/utils.py create mode 100644 xiaomusic/utils/__init__.py create mode 100644 xiaomusic/utils/file_utils.py create mode 100644 xiaomusic/utils/music_utils.py create mode 100644 xiaomusic/utils/network_utils.py create mode 100644 xiaomusic/utils/system_utils.py create mode 100644 xiaomusic/utils/text_utils.py diff --git a/xiaomusic/utils.py b/xiaomusic/utils.py deleted file mode 100644 index 87c6e5e..0000000 --- a/xiaomusic/utils.py +++ /dev/null @@ -1,1502 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -import asyncio -import base64 -import copy -import difflib -import hashlib -import io -import json -import logging -import mimetypes -import os -import platform -import random -import re -import shutil -import string -import subprocess -import tempfile -import time -import urllib.parse -from collections import OrderedDict -from collections.abc import AsyncIterator -from dataclasses import asdict, dataclass -from http.cookies import SimpleCookie -from pathlib import Path -from time import sleep -from urllib.parse import parse_qs, urlparse - -import aiohttp -import edge_tts -import mutagen -from mutagen.asf import ASF -from mutagen.flac import FLAC -from mutagen.id3 import ( - APIC, - ID3, - 
TALB, - TCON, - TDRC, - TIT2, - TPE1, - USLT, - Encoding, - TextFrame, - TimeStampTextFrame, -) -from mutagen.mp3 import MP3 -from mutagen.mp4 import MP4 -from mutagen.oggvorbis import OggVorbis -from mutagen.wave import WAVE -from mutagen.wavpack import WavPack -from opencc import OpenCC -from PIL import Image -from requests.utils import cookiejar_from_dict - -from xiaomusic.const import SUPPORT_MUSIC_TYPE - -log = logging.getLogger(__package__) - -cc = OpenCC("t2s") # convert from Traditional Chinese to Simplified Chinese - - -### HELP FUNCTION ### -def parse_cookie_string(cookie_string): - cookie = SimpleCookie() - cookie.load(cookie_string) - cookies_dict = {k: m.value for k, m in cookie.items()} - return cookiejar_from_dict(cookies_dict, cookiejar=None, overwrite=True) - - -_no_elapse_chars = re.compile(r"([「」『』《》“”'\"()()]|(? float: - # for simplicity, we use a fixed speed - speed = 4.5 # this value is picked by trial and error - # Exclude quotes and brackets that do not affect the total elapsed time - return len(_no_elapse_chars.sub("", text)) / speed - - -_ending_punctuations = ("。", "?", "!", ";", ".", "?", "!", ";") - - -async def split_sentences(text_stream: AsyncIterator[str]) -> AsyncIterator[str]: - cur = "" - async for text in text_stream: - cur += text - if cur.endswith(_ending_punctuations): - yield cur - cur = "" - if cur: - yield cur - - -### for edge-tts utils ### -def find_key_by_partial_string(dictionary: dict[str, str], partial_key: str) -> str: - for key, value in dictionary.items(): - if key in partial_key: - return value - - -def validate_proxy(proxy_str: str) -> bool: - """Do a simple validation of the http proxy string.""" - - parsed = urlparse(proxy_str) - if parsed.scheme not in ("http", "https"): - raise ValueError("Proxy scheme must be http or https") - if not (parsed.hostname and parsed.port): - raise ValueError("Proxy hostname and port must be set") - - return True - - -# 模糊搜索 -def fuzzyfinder(user_input, collection, 
extra_search_index=None): - return find_best_match( - user_input, collection, cutoff=0.1, n=10, extra_search_index=extra_search_index - ) - - -def traditional_to_simple(to_convert: str): - return cc.convert(to_convert) - - -# 关键词检测 -def keyword_detection(user_input, str_list, n): - # 过滤包含关键字的字符串 - matched, remains = [], [] - for item in str_list: - if user_input in item: - matched.append(item) - else: - remains.append(item) - - matched = sorted( - matched, - key=lambda s: difflib.SequenceMatcher(None, s, user_input).ratio(), - reverse=True, # 降序排序,越相似的越靠前 - ) - - # 如果 n 是 -1,如果 n 大于匹配的数量,返回所有匹配的结果 - if n == -1 or n > len(matched): - return matched, remains - - # 选择前 n 个匹配的结果 - remains = matched[n:] + remains - return matched[:n], remains - - -def real_search(prompt, candidates, cutoff, n): - matches, remains = keyword_detection(prompt, candidates, n=n) - if len(matches) < n: - # 如果没有准确关键词匹配,开始模糊匹配 - matches += difflib.get_close_matches(prompt, remains, n=n, cutoff=cutoff) - return matches - - -def find_best_match(user_input, collection, cutoff=0.6, n=1, extra_search_index=None): - lower_collection = { - traditional_to_simple(item.lower()): item for item in collection - } - user_input = traditional_to_simple(user_input.lower()) - matches = real_search(user_input, lower_collection.keys(), cutoff, n) - cur_matched_collection = [lower_collection[match] for match in matches] - if len(matches) >= n or extra_search_index is None: - return cur_matched_collection[:n] - - # 如果数量不满足,继续搜索 - lower_extra_search_index = { - traditional_to_simple(k.lower()): v - for k, v in extra_search_index.items() - if v not in cur_matched_collection - } - matches = real_search(user_input, lower_extra_search_index.keys(), cutoff, n) - cur_matched_collection += [lower_extra_search_index[match] for match in matches] - return cur_matched_collection[:n] - - -# 歌曲排序 -def custom_sort_key(s): - # 使用正则表达式分别提取字符串的数字前缀和数字后缀 - prefix_match = re.match(r"^(\d+)", s) - suffix_match = re.search(r"(\d+)$", s) 
- - numeric_prefix = int(prefix_match.group(0)) if prefix_match else None - numeric_suffix = int(suffix_match.group(0)) if suffix_match else None - - if numeric_prefix is not None: - # 如果前缀是数字,先按前缀数字排序,再按整个字符串排序 - return (0, numeric_prefix, s) - elif numeric_suffix is not None: - # 如果后缀是数字,先按前缀字符排序,再按后缀数字排序 - return (1, s[: suffix_match.start()], numeric_suffix) - else: - # 如果前缀和后缀都不是数字,按字典序排序 - return (2, s) - - -def _get_depth_path(root, directory, depth): - # 计算当前目录的深度 - relative_path = root[len(directory) :].strip(os.sep) - path_parts = relative_path.split(os.sep) - if len(path_parts) >= depth: - return os.path.join(directory, *path_parts[:depth]) - else: - return root - - -def _append_files_result(result, root, joinpath, files, support_extension): - dir_name = os.path.basename(root) - if dir_name not in result: - result[dir_name] = [] - for file in files: - # 过滤隐藏文件 - if file.startswith("."): - continue - # 过滤文件后缀 - (name, extension) = os.path.splitext(file) - if extension.lower() not in support_extension: - continue - - result[dir_name].append(os.path.join(joinpath, file)) - - -def traverse_music_directory(directory, depth, exclude_dirs, support_extension): - result = {} - for root, dirs, files in os.walk(directory, followlinks=True): - # 忽略排除的目录 - dirs[:] = [d for d in dirs if d not in exclude_dirs] - - # 计算当前目录的深度 - current_depth = root[len(directory) :].count(os.sep) + 1 - if current_depth > depth: - depth_path = _get_depth_path(root, directory, depth - 1) - _append_files_result(result, depth_path, root, files, support_extension) - else: - _append_files_result(result, root, root, files, support_extension) - return result - - -async def downloadfile(url): - # 清理和验证URL - # 解析URL - parsed_url = urlparse(url) - # 基础验证:仅允许HTTP和HTTPS协议 - if parsed_url.scheme not in ("http", "https"): - raise Warning( - f"Invalid URL scheme: {parsed_url.scheme}. Only HTTP and HTTPS are allowed." 
- ) - # 构建目标URL - cleaned_url = parsed_url.geturl() - - # 使用 aiohttp 创建一个客户端会话来发起请求 - async with aiohttp.ClientSession() as session: - async with session.get( - cleaned_url, timeout=5 - ) as response: # 增加超时以避免长时间挂起 - # 如果响应不是200,引发异常 - response.raise_for_status() - # 读取响应文本 - text = await response.text() - return text - - -def is_mp3(url): - mt = mimetypes.guess_type(url) - if mt and mt[0] == "audio/mpeg": - return True - return False - - -def is_m4a(url): - return url.endswith(".m4a") - - -async def _get_web_music_duration(session, url, config, start=0, end=500): - """ - 异步获取网络音乐文件的部分内容并估算其时长。 - - 通过请求 URL 的前几个字节(默认 0-500)下载部分文件, - 写入临时文件后调用本地工具(如 ffprobe)获取音频时长。 - - :param session: aiohttp.ClientSession 实例 - :param url: 音乐文件的 URL 地址 - :param config: 包含配置信息的对象(如 ffmpeg 路径) - :param start: 请求的起始字节位置 - :param end: 请求的结束字节位置 - :return: 返回音频的持续时间(秒),如果失败则返回 0 - """ - duration = 0 - # 设置请求头 Range,只请求部分内容(用于快速获取元数据) - headers = {"Range": f"bytes={start}-{end}"} - - # 使用 aiohttp 异步发起 GET 请求,获取部分音频内容 - async with session.get(url, headers=headers) as response: - array_buffer = await response.read() # 读取响应的二进制内容 - - # 创建一个命名的临时文件,并禁用自动删除(以便后续读取) - with tempfile.NamedTemporaryFile(delete=False) as tmp: - tmp.write(array_buffer) # 将下载的部分内容写入临时文件 - tmp_path = tmp.name # 获取该临时文件的真实路径 - - try: - # 调用 get_local_music_duration 并传入文件路径,而不是文件对象 - duration = await get_local_music_duration(tmp_path, config) - except Exception as e: - log.error(f"Error _get_web_music_duration: {e}") - finally: - # 手动删除临时文件,避免残留 - os.unlink(tmp_path) - - return duration - - -async def get_web_music_duration(url, config): - duration = 0 - try: - parsed_url = urlparse(url) - file_path = parsed_url.path - _, extension = os.path.splitext(file_path) - if extension.lower() not in SUPPORT_MUSIC_TYPE: - cleaned_url = parsed_url.geturl() - async with aiohttp.ClientSession() as session: - async with session.get( - cleaned_url, - allow_redirects=True, - headers={ - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac 
OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36" - }, - ) as response: - url = str(response.url) - # 设置总超时时间为3秒 - timeout = aiohttp.ClientTimeout(total=3) - async with aiohttp.ClientSession(timeout=timeout) as session: - duration = await _get_web_music_duration( - session, url, config, start=0, end=500 - ) - if duration <= 0: - duration = await _get_web_music_duration( - session, url, config, start=0, end=3000 - ) - except Exception as e: - log.error(f"Error get_web_music_duration: {e}") - return duration, url - - -# 获取文件播放时长 -async def get_local_music_duration(filename, config): - duration = 0 - if config.get_duration_type == "ffprobe": - duration = get_duration_by_ffprobe(filename, config.ffmpeg_location) - else: - duration = await get_duration_by_mutagen(filename) - - # 换个方式重试一次 - if duration == 0: - if config.get_duration_type != "ffprobe": - duration = get_duration_by_ffprobe(filename, config.ffmpeg_location) - else: - duration = await get_duration_by_mutagen(filename) - - return duration - - -async def get_duration_by_mutagen(file_path): - duration = 0 - try: - loop = asyncio.get_event_loop() - if is_mp3(file_path): - m = await loop.run_in_executor(None, mutagen.mp3.MP3, file_path) - else: - m = await loop.run_in_executor(None, mutagen.File, file_path) - duration = m.info.length - except Exception as e: - log.warning(f"Error getting local music {file_path} duration: {e}") - return duration - - -def get_duration_by_ffprobe(file_path, ffmpeg_location): - duration = 0 - try: - # 构造 ffprobe 命令参数 - cmd_args = [ - os.path.join(ffmpeg_location, "ffprobe"), - "-v", - "error", # 只输出错误信息,避免混杂在其他输出中 - "-show_entries", - "format=duration", # 仅显示时长 - "-of", - "json", # 以 JSON 格式输出 - file_path, - ] - - # 输出待执行的完整命令 - full_command = " ".join(cmd_args) - log.info(f"待执行的完整命令 ffprobe command: {full_command}") - - # 使用 ffprobe 获取文件的元数据,并以 JSON 格式输出 - result = subprocess.run( - cmd_args, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, 
- text=True, - ) - - # 输出命令执行结果 - log.info( - f"命令执行结果 command result - return code: {result.returncode}, stdout: {result.stdout}" - ) - - # 解析 JSON 输出 - ffprobe_output = json.loads(result.stdout) - - # 获取时长 - duration = float(ffprobe_output["format"]["duration"]) - log.info( - f"Successfully extracted duration: {duration} seconds for file: {file_path}" - ) - - except Exception as e: - log.warning(f"Error getting local music {file_path} duration: {e}") - return duration - - -def get_random(length): - return "".join(random.sample(string.ascii_letters + string.digits, length)) - - -# 深拷贝把敏感数据设置为* -def deepcopy_data_no_sensitive_info(data, fields_to_anonymize=None): - if fields_to_anonymize is None: - fields_to_anonymize = [ - "account", - "password", - "httpauth_username", - "httpauth_password", - ] - - copy_data = copy.deepcopy(data) - - # 检查copy_data是否是字典或具有属性的对象 - if isinstance(copy_data, dict): - # 对字典进行处理 - for field in fields_to_anonymize: - if field in copy_data: - copy_data[field] = "******" - else: - # 对对象进行处理 - for field in fields_to_anonymize: - if hasattr(copy_data, field): - setattr(copy_data, field, "******") - - return copy_data - - -# k1:v1,k2:v2 -def parse_str_to_dict(s, d1=",", d2=":"): - # 初始化一个空字典 - result = {} - parts = s.split(d1) - - for part in parts: - # 根据冒号切割 - subparts = part.split(d2) - if len(subparts) == 2: # 防止数据不是成对出现 - k, v = subparts - result[k] = v - - return result - - -# remove mp3 file id3 tag and padding to reduce delay -def no_padding(info): - # this will remove all padding - return 0 - - -def remove_id3_tags(input_file: str, config) -> str: - audio = MP3(input_file, ID3=ID3) - - # 检查是否存在ID3 v2.3或v2.4标签 - if not ( - audio.tags - and (audio.tags.version == (2, 3, 0) or audio.tags.version == (2, 4, 0)) - ): - return None - - music_path = config.music_path - temp_dir = config.temp_dir - - # 构造新文件的路径 - out_file_name = os.path.splitext(os.path.basename(input_file))[0] - out_file_path = os.path.join(temp_dir, f"{out_file_name}.mp3") 
- relative_path = os.path.relpath(out_file_path, music_path) - - # 路径相同的情况 - input_absolute_path = os.path.abspath(input_file) - output_absolute_path = os.path.abspath(out_file_path) - if input_absolute_path == output_absolute_path: - log.info(f"File {input_file} = {out_file_path} . Skipping remove_id3_tags.") - return None - - # 检查目标文件是否存在 - if os.path.exists(out_file_path): - log.info(f"File {out_file_path} already exists. Skipping remove_id3_tags.") - return relative_path - - # 开始去除(不再需要检查) - # 拷贝文件 - shutil.copy(input_file, out_file_path) - outaudio = MP3(out_file_path, ID3=ID3) - # 删除ID3标签 - outaudio.delete() - # 保存修改后的文件 - outaudio.save(padding=no_padding) - log.info(f"File {out_file_path} remove_id3_tags ok.") - return relative_path - - -def convert_file_to_mp3(input_file: str, config) -> str: - music_path = config.music_path - temp_dir = config.temp_dir - - out_file_name = os.path.splitext(os.path.basename(input_file))[0] - out_file_path = os.path.join(temp_dir, f"{out_file_name}.mp3") - relative_path = os.path.relpath(out_file_path, music_path) - - # 路径相同的情况 - input_absolute_path = os.path.abspath(input_file) - output_absolute_path = os.path.abspath(out_file_path) - if input_absolute_path == output_absolute_path: - log.info(f"File {input_file} = {out_file_path} . Skipping convert_file_to_mp3.") - return None - - absolute_music_path = os.path.abspath(music_path) - if not input_absolute_path.startswith(absolute_music_path): - log.error(f"Invalid input file path: {input_file}") - return None - - # 检查目标文件是否存在 - if os.path.exists(out_file_path): - log.info(f"File {out_file_path} already exists. 
Skipping convert_file_to_mp3.") - return relative_path - - # 检查是否存在 loudnorm 参数 - loudnorm_args = [] - if config.loudnorm: - loudnorm_args = ["-af", config.loudnorm] - - command = [ - os.path.join(config.ffmpeg_location, "ffmpeg"), - "-i", - input_absolute_path, - "-f", - "mp3", - "-vn", - "-y", - *loudnorm_args, - out_file_path, - ] - - try: - subprocess.run(command, check=True) - except subprocess.CalledProcessError as e: - log.exception(f"Error during conversion: {e}") - return None - - log.info(f"File {input_file} to {out_file_path} convert_file_to_mp3 ok.") - return relative_path - - -chinese_to_arabic = { - "零": 0, - "一": 1, - "二": 2, - "三": 3, - "四": 4, - "五": 5, - "六": 6, - "七": 7, - "八": 8, - "九": 9, - "十": 10, - "百": 100, - "千": 1000, - "万": 10000, - "亿": 100000000, -} - - -def chinese_to_number(chinese): - result = 0 - unit = 1 - num = 0 - # 处理特殊情况:以"十"开头时,在前面加"一" - if chinese.startswith("十"): - chinese = "一" + chinese - - # 如果只有一个字符且是单位,直接返回其值 - if len(chinese) == 1 and chinese_to_arabic[chinese] >= 10: - return chinese_to_arabic[chinese] - for char in reversed(chinese): - if char in chinese_to_arabic: - val = chinese_to_arabic[char] - if val >= 10: - if val > unit: - unit = val - else: - unit *= val - else: - num += val * unit - result += num - - return result - - -def list2str(li, verbose=False): - if len(li) > 5 and not verbose: - return f"{li[:2]} ... 
{li[-2:]} with len: {len(li)}" - else: - return f"{li}" - - -async def get_latest_version(package_name: str) -> str: - url = f"https://pypi.org/pypi/{package_name}/json" - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - data = await response.json() - return data["info"]["version"] - else: - return None - - -@dataclass -class Metadata: - title: str = "" - artist: str = "" - album: str = "" - year: str = "" - genre: str = "" - picture: str = "" - lyrics: str = "" - - def __init__(self, info=None): - if info: - self.title = info.get("title", "") - self.artist = info.get("artist", "") - self.album = info.get("album", "") - self.year = info.get("year", "") - self.genre = info.get("genre", "") - self.picture = info.get("picture", "") - self.lyrics = info.get("lyrics", "") - - -def _get_alltag_value(tags, k): - v = tags.getall(k) - if len(v) > 0: - return _to_utf8(v[0]) - return "" - - -def _get_tag_value(tags, k): - if k not in tags: - return "" - v = tags[k] - return _to_utf8(v) - - -def _to_utf8(v): - if isinstance(v, TextFrame) and not isinstance(v, TimeStampTextFrame): - old_ts = "".join(v.text) - if v.encoding == Encoding.LATIN1: - bs = old_ts.encode("latin1") - ts = bs.decode("GBK", errors="ignore") - return ts - return old_ts - elif isinstance(v, list): - return "".join(str(item) for item in v) - return str(v) - - -def save_picture_by_base64(picture_base64_data, save_root, file_path): - try: - picture_data = base64.b64decode(picture_base64_data) - except (TypeError, ValueError) as e: - log.exception(f"Error decoding base64 data: {e}") - return None - return _save_picture(picture_data, save_root, file_path) - - -def _save_picture(picture_data, save_root, file_path): - # 计算文件名的哈希值 - file_hash = hashlib.md5(file_path.encode("utf-8")).hexdigest() - # 创建目录结构 - dir_path = os.path.join(save_root, file_hash[-6:]) - os.makedirs(dir_path, exist_ok=True) - - # 保存图片 - filename = 
os.path.basename(file_path) - (name, _) = os.path.splitext(filename) - picture_path = os.path.join(dir_path, f"{name}.jpg") - - try: - _resize_save_image(picture_data, picture_path) - except Exception as e: - log.warning(f"Error _resize_save_image: {e}") - return picture_path - - -def _resize_save_image(image_bytes, save_path, max_size=300): - # 将 bytes 转换为 PIL Image 对象 - image = None - try: - image = Image.open(io.BytesIO(image_bytes)) - image = image.convert("RGB") - except Exception as e: - log.warning(f"Error _resize_save_image: {e}") - return - - # 获取原始尺寸 - original_width, original_height = image.size - - # 如果图片的宽度和高度都小于 max_size,则直接保存原始图片 - if original_width <= max_size and original_height <= max_size: - image.save(save_path, format="JPEG") - return - - # 计算缩放比例,保持等比缩放 - scaling_factor = min(max_size / original_width, max_size / original_height) - - # 计算新的尺寸 - new_width = int(original_width * scaling_factor) - new_height = int(original_height * scaling_factor) - - resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) - resized_image.save(save_path, format="JPEG") - return save_path - - -def extract_audio_metadata(file_path, save_root): - metadata = Metadata() - - audio = None - try: - audio = mutagen.File(file_path) - except Exception as e: - log.warning(f"Error extract_audio_metadata file: {file_path} {e}") - if audio is None: - return asdict(metadata) - - tags = audio.tags - if tags is None: - return asdict(metadata) - - if isinstance(audio, MP3): - metadata.title = _get_tag_value(tags, "TIT2") - metadata.artist = _get_tag_value(tags, "TPE1") - metadata.album = _get_tag_value(tags, "TALB") - metadata.year = _get_tag_value(tags, "TDRC") - metadata.genre = _get_tag_value(tags, "TCON") - metadata.lyrics = _get_alltag_value(tags, "USLT") - for tag in tags.values(): - if isinstance(tag, APIC): - metadata.picture = _save_picture(tag.data, save_root, file_path) - break - - elif isinstance(audio, FLAC): - metadata.title = 
_get_tag_value(tags, "TITLE") - metadata.artist = _get_tag_value(tags, "ARTIST") - metadata.album = _get_tag_value(tags, "ALBUM") - metadata.year = _get_tag_value(tags, "DATE") - metadata.genre = _get_tag_value(tags, "GENRE") - if audio.pictures: - metadata.picture = _save_picture( - audio.pictures[0].data, save_root, file_path - ) - if "lyrics" in audio: - metadata.lyrics = audio["lyrics"][0] - - elif isinstance(audio, MP4): - metadata.title = _get_tag_value(tags, "\xa9nam") - metadata.artist = _get_tag_value(tags, "\xa9ART") - metadata.album = _get_tag_value(tags, "\xa9alb") - metadata.year = _get_tag_value(tags, "\xa9day") - metadata.genre = _get_tag_value(tags, "\xa9gen") - if "covr" in tags: - metadata.picture = _save_picture(tags["covr"][0], save_root, file_path) - - elif isinstance(audio, OggVorbis): - metadata.title = _get_tag_value(tags, "TITLE") - metadata.artist = _get_tag_value(tags, "ARTIST") - metadata.album = _get_tag_value(tags, "ALBUM") - metadata.year = _get_tag_value(tags, "DATE") - metadata.genre = _get_tag_value(tags, "GENRE") - if "metadata_block_picture" in tags: - picture = json.loads(base64.b64decode(tags["metadata_block_picture"][0])) - metadata.picture = _save_picture( - base64.b64decode(picture["data"]), save_root, file_path - ) - - elif isinstance(audio, ASF): - metadata.title = _get_tag_value(tags, "Title") - metadata.artist = _get_tag_value(tags, "Author") - metadata.album = _get_tag_value(tags, "WM/AlbumTitle") - metadata.year = _get_tag_value(tags, "WM/Year") - metadata.genre = _get_tag_value(tags, "WM/Genre") - if "WM/Picture" in tags: - metadata.picture = _save_picture( - tags["WM/Picture"][0].value, save_root, file_path - ) - - elif isinstance(audio, WavPack): - metadata.title = _get_tag_value(tags, "Title") - metadata.artist = _get_tag_value(tags, "Artist") - metadata.album = _get_tag_value(tags, "Album") - metadata.year = _get_tag_value(tags, "Year") - metadata.genre = _get_tag_value(tags, "Genre") - if audio.pictures: - 
metadata.picture = _save_picture( - audio.pictures[0].data, save_root, file_path - ) - - elif isinstance(audio, WAVE): - metadata.title = _get_tag_value(tags, "Title") - metadata.artist = _get_tag_value(tags, "Artist") - - return asdict(metadata) - - -def set_music_tag_to_file(file_path, info): - audio = mutagen.File(file_path, easy=True) - if audio is None: - log.error(f"Unable to open file {file_path}") - return "Unable to open file" - - if isinstance(audio, MP3): - _set_mp3_tags(audio, info) - elif isinstance(audio, FLAC): - _set_flac_tags(audio, info) - elif isinstance(audio, MP4): - _set_mp4_tags(audio, info) - elif isinstance(audio, OggVorbis): - _set_ogg_tags(audio, info) - elif isinstance(audio, ASF): - _set_asf_tags(audio, info) - elif isinstance(audio, WAVE): - _set_wave_tags(audio, info) - else: - log.error(f"Unsupported file type for {file_path}") - return "Unsupported file type" - - try: - audio.save() - log.info(f"Tags saved successfully to {file_path}") - return "OK" - except Exception as e: - log.exception(f"Error saving tags: {e}") - return "Error saving tags" - - -def _set_mp3_tags(audio, info): - audio.tags = ID3() - audio["TIT2"] = TIT2(encoding=3, text=info.title) - audio["TPE1"] = TPE1(encoding=3, text=info.artist) - audio["TALB"] = TALB(encoding=3, text=info.album) - audio["TDRC"] = TDRC(encoding=3, text=info.year) - audio["TCON"] = TCON(encoding=3, text=info.genre) - - # 使用 USLT 存储歌词 - if info.lyrics: - audio["USLT"] = USLT(encoding=3, lang="eng", text=info.lyrics) - - # 添加封面图片 - if info.picture: - with open(info.picture, "rb") as img_file: - image_data = img_file.read() - audio["APIC"] = APIC( - encoding=3, mime="image/jpeg", type=3, desc="Cover", data=image_data - ) - audio.save() # 保存修改 - - -def _set_flac_tags(audio, info): - audio["TITLE"] = info.title - audio["ARTIST"] = info.artist - audio["ALBUM"] = info.album - audio["DATE"] = info.year - audio["GENRE"] = info.genre - if info.lyrics: - audio["LYRICS"] = info.lyrics - if info.picture: 
- with open(info.picture, "rb") as img_file: - image_data = img_file.read() - audio.add_picture(image_data) - - -def _set_mp4_tags(audio, info): - audio["nam"] = info.title - audio["ART"] = info.artist - audio["alb"] = info.album - audio["day"] = info.year - audio["gen"] = info.genre - if info.picture: - with open(info.picture, "rb") as img_file: - image_data = img_file.read() - audio["covr"] = [image_data] - - -def _set_ogg_tags(audio, info): - audio["TITLE"] = info.title - audio["ARTIST"] = info.artist - audio["ALBUM"] = info.album - audio["DATE"] = info.year - audio["GENRE"] = info.genre - if info.lyrics: - audio["LYRICS"] = info.lyrics - if info.picture: - with open(info.picture, "rb") as img_file: - image_data = img_file.read() - audio["metadata_block_picture"] = base64.b64encode(image_data).decode() - - -def _set_asf_tags(audio, info): - audio["Title"] = info.title - audio["Author"] = info.artist - audio["WM/AlbumTitle"] = info.album - audio["WM/Year"] = info.year - audio["WM/Genre"] = info.genre - if info.picture: - with open(info.picture, "rb") as img_file: - image_data = img_file.read() - audio["WM/Picture"] = image_data - - -def _set_wave_tags(audio, info): - audio["Title"] = info.title - audio["Artist"] = info.artist - - -async def check_bili_fav_list(url): - bvid_info = {} - parsed_url = urlparse(url) - path = parsed_url.path - # 提取查询参数 - query_params = parse_qs(parsed_url.query) - if parsed_url.hostname == "space.bilibili.com": - if "/favlist" in path: - lid = query_params.get("fid", [None])[0] - type = query_params.get("ctype", [None])[0] - if type == "11": - type = "create" - elif type == "21": - type = "collect" - else: - raise ValueError("当前只支持合集和收藏夹") - elif "/lists/" in path: - parts = path.split("/") - if len(parts) >= 4 and "?" 
in url: - lid = parts[3] # 提取 lid - type = query_params.get("type", [None])[0] - # https://api.bilibili.com/x/polymer/web-space/seasons_archives_list?season_id={lid}&page_size=30&page_num=1 - page_size = 100 - page_num = 1 - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0 Safari/537.36", - "Accept": "application/json, text/plain, */*", - "Referer": url, - "Origin": "https://space.bilibili.com", - } - async with aiohttp.ClientSession(headers=headers) as session: - if type == "season" or type == "collect": - while True: - list_url = f"https://api.bilibili.com/x/polymer/web-space/seasons_archives_list?season_id={lid}&page_size={page_size}&page_num={page_num}" - async with session.get(list_url) as response: - if response.status != 200: - raise Exception(f"Failed to fetch data from {list_url}") - data = await response.json() - archives = data.get("data", {}).get("archives", []) - if not archives: - break - for archive in archives: - bvid = archive.get("bvid", None) - title = archive.get("title", None) - bvid_info[bvid] = title - - if len(archives) < page_size: - break - page_num += 1 - sleep(1) - elif type == "create": - while True: - list_url = f"https://api.bilibili.com/x/v3/fav/resource/list?media_id={lid}&pn={page_num}&ps={page_size}&order=mtime" - async with session.get(list_url) as response: - if response.status != 200: - raise Exception(f"Failed to fetch data from {list_url}") - data = await response.json() - medias = data.get("data", {}).get("medias", []) - if not medias: - break - for media in medias: - bvid = media.get("bvid", None) - title = media.get("title", None) - bvurl = f"https://www.bilibili.com/video/{bvid}" - bvid_info[bvurl] = title - - if len(medias) < page_size: - break - page_num += 1 - else: - raise ValueError("当前只支持合集和收藏夹") - return bvid_info - - -# 下载播放列表 -async def download_playlist(config, url, dirname): - title = f"{dirname}/%(title)s.%(ext)s" - sbp_args = ( - 
"yt-dlp", - "--yes-playlist", - "-x", - "--audio-format", - "mp3", - "--audio-quality", - "0", - "--paths", - config.download_path, - "-o", - title, - "--ffmpeg-location", - f"{config.ffmpeg_location}", - ) - - if config.proxy: - sbp_args += ("--proxy", f"{config.proxy}") - - if config.enable_yt_dlp_cookies: - sbp_args += ("--cookies", f"{config.yt_dlp_cookies_path}") - - if config.loudnorm: - sbp_args += ("--postprocessor-args", f"-af {config.loudnorm}") - - sbp_args += (url,) - - cmd = " ".join(sbp_args) - log.info(f"download_playlist: {cmd}") - download_proc = await asyncio.create_subprocess_exec(*sbp_args) - return download_proc - - -# 下载一首歌曲 -async def download_one_music(config, url, name=""): - title = "%(title)s.%(ext)s" - if name: - title = f"{name}.%(ext)s" - sbp_args = ( - "yt-dlp", - "--no-playlist", - "-x", - "--audio-format", - "mp3", - "--audio-quality", - "0", - "--paths", - config.download_path, - "-o", - title, - "--ffmpeg-location", - f"{config.ffmpeg_location}", - ) - - if config.proxy: - sbp_args += ("--proxy", f"{config.proxy}") - - if config.enable_yt_dlp_cookies: - sbp_args += ("--cookies", f"{config.yt_dlp_cookies_path}") - - if config.loudnorm: - sbp_args += ("--postprocessor-args", f"-af {config.loudnorm}") - - sbp_args += (url,) - - cmd = " ".join(sbp_args) - log.info(f"download_one_music: {cmd}") - download_proc = await asyncio.create_subprocess_exec(*sbp_args) - return download_proc - - -def _longest_common_prefix(file_names): - if not file_names: - return "" - - # 将第一个文件名作为初始前缀 - prefix = file_names[0] - - for file_name in file_names[1:]: - while not file_name.startswith(prefix): - # 如果当前文件名不以prefix开头,则缩短prefix - prefix = prefix[:-1] - if not prefix: - return "" - - return prefix - - -def safe_join_path(safe_root, directory): - directory = os.path.join(safe_root, directory) - # Normalize the directory path - normalized_directory = os.path.normpath(directory) - # Ensure the directory is within the safe root - if not 
normalized_directory.startswith(os.path.normpath(safe_root)): - raise ValueError(f"Access to directory '{directory}' is not allowed.") - return normalized_directory - - -# 移除目录下文件名前缀相同的 -def remove_common_prefix(directory): - files = os.listdir(directory) - - # 获取所有文件的前缀 - common_prefix = _longest_common_prefix(files) - - log.info(f'Common prefix identified: "{common_prefix}"') - - pattern = re.compile(r"^[pP]?(\d+)\s+\d*(.+?)\.(.*$)") - for filename in files: - if filename == common_prefix: - continue - # 检查文件名是否以共同前缀开头 - if filename.startswith(common_prefix): - # 构造新的文件名 - new_filename = filename[len(common_prefix) :] - match = pattern.search(new_filename.strip()) - if match: - num = match.group(1) - name = match.group(2).replace(".", " ").strip() - suffix = match.group(3) - new_filename = f"{num}.{name}.{suffix}" - # 生成完整的文件路径 - old_file_path = os.path.join(directory, filename) - new_file_path = os.path.join(directory, new_filename) - - # 重命名文件 - os.rename(old_file_path, new_file_path) - log.debug(f'Renamed: "{filename}" to "{new_filename}"') - - -def try_add_access_control_param(config, url): - if config.disable_httpauth: - return url - - url_parts = urllib.parse.urlparse(url) - file_path = urllib.parse.unquote(url_parts.path) - correct_code = hashlib.sha256( - (file_path + config.httpauth_username + config.httpauth_password).encode( - "utf-8" - ) - ).hexdigest() - log.debug(f"rewrite url: [{file_path}, {correct_code}]") - - # make new url - parsed_get_args = dict(urllib.parse.parse_qsl(url_parts.query)) - parsed_get_args.update({"code": correct_code}) - encoded_get_args = urllib.parse.urlencode(parsed_get_args, doseq=True) - new_url = urllib.parse.ParseResult( - url_parts.scheme, - url_parts.netloc, - url_parts.path, - url_parts.params, - encoded_get_args, - url_parts.fragment, - ).geturl() - - return new_url - - -# 判断文件在不在排除目录列表 -def not_in_dirs(filename, ignore_absolute_dirs): - file_absolute_path = os.path.abspath(filename) - file_dir = 
os.path.dirname(file_absolute_path) - for ignore_dir in ignore_absolute_dirs: - if file_dir.startswith(ignore_dir): - log.info(f"{file_dir} in {ignore_dir}") - return False # 文件在排除目录中 - - return True # 文件不在排除目录中 - - -def is_docker(): - return os.path.exists("/app/.dockerenv") - - -async def restart_xiaomusic(): - # 重启 xiaomusic 程序 - sbp_args = ( - "supervisorctl", - "restart", - "xiaomusic", - ) - - cmd = " ".join(sbp_args) - log.info(f"restart_xiaomusic: {cmd}") - await asyncio.sleep(2) - proc = await asyncio.create_subprocess_exec(*sbp_args) - exit_code = await proc.wait() # 等待子进程完成 - log.info(f"restart_xiaomusic completed with exit code {exit_code}") - return exit_code - - -async def update_version(version: str, lite: bool = True): - if not is_docker(): - ret = "xiaomusic 更新只能在 docker 中进行" - log.info(ret) - return ret - lite_tag = "" - if lite: - lite_tag = "-lite" - arch = get_os_architecture() - if "unknown" in arch: - log.warning(f"update_version failed: {arch}") - return arch - # https://github.com/hanxi/xiaomusic/releases/download/main/app-amd64-lite.tar.gz - url = f"https://gproxy.hanxi.cc/proxy/hanxi/xiaomusic/releases/download/{version}/app-{arch}{lite_tag}.tar.gz" - target_directory = "/app" - return await download_and_extract(url, target_directory) - - -def get_os_architecture(): - """ - 获取操作系统架构类型:amd64、arm64、arm-v7。 - - Returns: - str: 架构类型 - """ - arch = platform.machine().lower() - - if arch in ("x86_64", "amd64"): - return "amd64" - elif arch in ("aarch64", "arm64"): - return "arm64" - elif "arm" in arch or "armv7" in arch: - return "arm-v7" - else: - return f"unknown architecture: {arch}" - - -async def download_and_extract(url: str, target_directory: str): - ret = "OK" - # 创建目标目录 - os.makedirs(target_directory, exist_ok=True) - - # 使用 aiohttp 异步下载文件 - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - file_name = os.path.join(target_directory, url.split("/")[-1]) - file_name = 
os.path.normpath(file_name) - if not file_name.startswith(target_directory): - log.warning(f"Invalid file path: {file_name}") - return - with open(file_name, "wb") as f: - # 以块的方式下载文件,防止内存占用过大 - async for chunk in response.content.iter_any(): - f.write(chunk) - log.info(f"文件下载完成: {file_name}") - - # 解压下载的文件 - if file_name.endswith(".tar.gz"): - await extract_tar_gz(file_name, target_directory) - else: - ret = f"下载失败, 包有问题: {file_name}" - log.warning(ret) - - else: - ret = f"下载失败, 状态码: {response.status}" - log.warning(ret) - return ret - - -async def extract_tar_gz(file_name: str, target_directory: str): - # 使用 asyncio.create_subprocess_exec 执行 tar 解压命令 - command = ["tar", "-xzvf", file_name, "-C", target_directory] - # 启动子进程执行解压命令 - await asyncio.create_subprocess_exec(*command) - # 不等待子进程完成 - log.info(f"extract_tar_gz ing {file_name}") - - -def chmodfile(file_path: str): - try: - os.chmod(file_path, 0o775) - except Exception as e: - log.info(f"chmodfile failed: {e}") - - -def chmoddir(dir_path: str): - # 获取指定目录下的所有文件和子目录 - for item in os.listdir(dir_path): - item_path = os.path.join(dir_path, item) - # 确保是文件,且不是目录 - if os.path.isfile(item_path): - try: - os.chmod(item_path, 0o775) - log.info(f"Changed permissions of file: {item_path}") - except Exception as e: - log.info(f"chmoddir failed: {e}") - - -async def fetch_json_get(url, headers, config): - connector = None - proxy = None - if config and config.proxy: - connector = aiohttp.TCPConnector( - ssl=False, # 如需验证SSL证书,可改为True(需确保代理支持) - limit=10, - ) - proxy = config.proxy - try: - # 2. 传入代理配置创建ClientSession - async with aiohttp.ClientSession(connector=connector) as session: - # 3. 
发起带代理的GET请求 - async with session.get( - url, - headers=headers, - proxy=proxy, # 传入格式化后的代理参数 - timeout=10, # 超时时间(秒),避免无限等待 - ) as response: - if response.status == 200: - data = await response.json() - log.info(f"fetch_json_get: {url} success {data}") - - # 确保返回结果为dict - if isinstance(data, dict): - return data - else: - log.warning(f"Expected dict, but got {type(data)}: {data}") - return {} - else: - log.error(f"HTTP Error: {response.status} {url}") - return {} - except aiohttp.ClientError as e: - log.error(f"ClientError fetching {url} (proxy: {proxy}): {e}") - return {} - except asyncio.TimeoutError: - log.error(f"Timeout fetching {url} (proxy: {proxy})") - return {} - except Exception as e: - log.error(f"Unexpected error fetching {url} (proxy: {proxy}): {e}") - return {} - finally: - # 4. 关闭连接器(避免资源泄漏) - if connector and not connector.closed: - await connector.close() - - -class LRUCache(OrderedDict): - def __init__(self, max_size=1000): - super().__init__() - self.max_size = max_size - - def __setitem__(self, key, value): - if key in self: - # 移动到末尾(最近使用) - self.move_to_end(key) - super().__setitem__(key, value) - # 如果超出大小限制,删除最早使用的项 - if len(self) > self.max_size: - self.popitem(last=False) - - def __getitem__(self, key): - # 访问时移动到末尾(最近使用) - if key in self: - self.move_to_end(key) - return super().__getitem__(key) - - -class MusicUrlCache: - def __init__(self, default_expire_days=1, max_size=1000): - self.cache = LRUCache(max_size) - self.default_expire_days = default_expire_days - self.log = logging.getLogger(__name__) - - async def get(self, url: str, headers: dict = None, config=None) -> str: - """获取URL(优先从缓存获取,没有则请求API) - - Args: - url: 原始URL - headers: API请求需要的headers - Returns: - str: 真实播放URL - """ - # 先查询缓存 - cached_url = self._get_from_cache(url) - if cached_url: - self.log.info(f"Using cached url: {cached_url}") - return cached_url - - # 缓存未命中,请求API - return await self._fetch_from_api(url, headers, config) - - def _get_from_cache(self, url: str) -> 
str: - """从缓存中获取URL""" - try: - cached_url, expire_time = self.cache[url] - if time.time() > expire_time: - # 缓存过期,删除 - del self.cache[url] - return "" - return cached_url - except KeyError: - return "" - - async def _fetch_from_api(self, url: str, headers: dict = None, config=None) -> str: - """从API获取真实URL""" - data = await fetch_json_get(url, headers or {}, config) - - if not isinstance(data, dict): - self.log.error(f"Invalid API response format: {data}") - return "" - - real_url = data.get("url") - if not real_url: - self.log.error(f"No url in API response: {data}") - return "" - - # 获取过期时间 - expire_time = self._parse_expire_time(data) - - # 缓存结果 - self._set_cache(url, real_url, expire_time) - self.log.info( - f"Cached url, expire_time: {expire_time}, cache size: {len(self.cache)}" - ) - return real_url - - def _parse_expire_time(self, data: dict) -> float | None: - """解析API返回的过期时间""" - try: - extra = data.get("extra", {}) - expire_info = extra.get("expire", {}) - if expire_info and expire_info.get("canExpire"): - expire_time = expire_info.get("time") - if expire_time: - return float(expire_time) - except Exception as e: - self.log.warning(f"Failed to parse expire time: {e}") - return None - - def _set_cache(self, url: str, real_url: str, expire_time: float = None): - """设置缓存""" - if expire_time is None: - expire_time = time.time() + (self.default_expire_days * 24 * 3600) - self.cache[url] = (real_url, expire_time) - - def clear(self): - """清空缓存""" - self.cache.clear() - - @property - def size(self) -> int: - """当前缓存大小""" - return len(self.cache) - - -async def text_to_mp3( - text: str, save_dir: str, voice: str = "zh-CN-XiaoxiaoNeural" -) -> str: - """ - 使用edge-tts将文本转换为MP3语音文件 - - 参数: - text: 需要转换的文本内容 - save_dir: 保存MP3文件的目录路径 - voice: 语音模型(默认中文晓晓) - - 返回: - str: 生成的MP3文件完整路径 - """ - # 确保保存目录存在 - Path(save_dir).mkdir(parents=True, exist_ok=True) - - # 基于文本和语音模型生成唯一文件名(避免相同文本不同语音重复) - content = f"{text}_{voice}".encode() - file_hash = 
hashlib.md5(content).hexdigest() - mp3_filename = f"{file_hash}.mp3" - mp3_path = os.path.join(save_dir, mp3_filename) - - # 文件已存在直接返回路径 - if os.path.exists(mp3_path): - return mp3_path - - # 调用edge-tts生成语音 - try: - communicate = edge_tts.Communicate(text, voice) - await communicate.save(mp3_path) - log.info(f"语音文件生成成功: {mp3_path}") - except Exception as e: - raise RuntimeError(f"生成语音文件失败: {e}") from e - - return mp3_path diff --git a/xiaomusic/utils/__init__.py b/xiaomusic/utils/__init__.py new file mode 100644 index 0000000..db8e05f --- /dev/null +++ b/xiaomusic/utils/__init__.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +""" +Utils package - 工具函数模块 + +将原 utils.py 拆分为多个职责清晰的子模块: +- text_utils: 文本处理和搜索 +- file_utils: 文件和目录操作 +- music_utils: 音乐文件处理 +- network_utils: 网络请求和下载 +- system_utils: 系统操作和环境 +""" + +# 从各子模块导入常用函数,保持向后兼容 +from xiaomusic.utils.file_utils import ( + chmoddir, + chmodfile, + not_in_dirs, + remove_common_prefix, + safe_join_path, + traverse_music_directory, +) +from xiaomusic.utils.music_utils import ( + Metadata, + convert_file_to_mp3, + extract_audio_metadata, + get_duration_by_ffprobe, + get_duration_by_mutagen, + get_local_music_duration, + get_web_music_duration, + is_m4a, + is_mp3, + remove_id3_tags, + save_picture_by_base64, + set_music_tag_to_file, +) +from xiaomusic.utils.network_utils import ( + MusicUrlCache, + check_bili_fav_list, + download_one_music, + download_playlist, + downloadfile, + fetch_json_get, + text_to_mp3, +) +from xiaomusic.utils.system_utils import ( + deepcopy_data_no_sensitive_info, + download_and_extract, + get_latest_version, + get_os_architecture, + get_random, + is_docker, + parse_cookie_string, + restart_xiaomusic, + try_add_access_control_param, + update_version, + validate_proxy, +) +from xiaomusic.utils.text_utils import ( + calculate_tts_elapse, + chinese_to_number, + custom_sort_key, + find_best_match, + find_key_by_partial_string, + fuzzyfinder, + keyword_detection, + list2str, + parse_str_to_dict, + 
split_sentences, + traditional_to_simple, +) + +__all__ = [ + # text_utils + "calculate_tts_elapse", + "chinese_to_number", + "custom_sort_key", + "find_best_match", + "find_key_by_partial_string", + "fuzzyfinder", + "keyword_detection", + "list2str", + "parse_str_to_dict", + "split_sentences", + "traditional_to_simple", + # file_utils + "chmoddir", + "chmodfile", + "not_in_dirs", + "remove_common_prefix", + "safe_join_path", + "traverse_music_directory", + # music_utils + "Metadata", + "convert_file_to_mp3", + "extract_audio_metadata", + "get_duration_by_ffprobe", + "get_duration_by_mutagen", + "get_local_music_duration", + "get_web_music_duration", + "is_m4a", + "is_mp3", + "remove_id3_tags", + "set_music_tag_to_file", + "save_picture_by_base64", + # network_utils + "MusicUrlCache", + "check_bili_fav_list", + "download_one_music", + "download_playlist", + "downloadfile", + "fetch_json_get", + "text_to_mp3", + # system_utils + "deepcopy_data_no_sensitive_info", + "download_and_extract", + "get_latest_version", + "get_os_architecture", + "get_random", + "is_docker", + "parse_cookie_string", + "restart_xiaomusic", + "try_add_access_control_param", + "update_version", + "validate_proxy", +] diff --git a/xiaomusic/utils/file_utils.py b/xiaomusic/utils/file_utils.py new file mode 100644 index 0000000..8a5855c --- /dev/null +++ b/xiaomusic/utils/file_utils.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +"""文件和目录操作相关工具函数""" + +import logging +import os +import re + +log = logging.getLogger(__package__) + + +def _get_depth_path(root: str, directory: str, depth: int) -> str: + """计算指定深度的路径""" + # 计算当前目录的深度 + relative_path = root[len(directory) :].strip(os.sep) + path_parts = relative_path.split(os.sep) + if len(path_parts) >= depth: + return os.path.join(directory, *path_parts[:depth]) + else: + return root + + +def _append_files_result( + result: dict, root: str, joinpath: str, files: list, support_extension: set +) -> None: + """将文件添加到结果字典中""" + dir_name = 
os.path.basename(root) + if dir_name not in result: + result[dir_name] = [] + for file in files: + # 过滤隐藏文件 + if file.startswith("."): + continue + # 过滤文件后缀 + (name, extension) = os.path.splitext(file) + if extension.lower() not in support_extension: + continue + + result[dir_name].append(os.path.join(joinpath, file)) + + +def traverse_music_directory( + directory: str, depth: int, exclude_dirs: set, support_extension: set +) -> dict: + """ + 遍历音乐目录 + + Args: + directory: 目录路径 + depth: 遍历深度 + exclude_dirs: 排除的目录集合 + support_extension: 支持的文件扩展名集合 + + Returns: + {目录名: [文件路径列表]} + """ + result = {} + for root, dirs, files in os.walk(directory, followlinks=True): + # 忽略排除的目录 + dirs[:] = [d for d in dirs if d not in exclude_dirs] + + # 计算当前目录的深度 + current_depth = root[len(directory) :].count(os.sep) + 1 + if current_depth > depth: + depth_path = _get_depth_path(root, directory, depth - 1) + _append_files_result(result, depth_path, root, files, support_extension) + else: + _append_files_result(result, root, root, files, support_extension) + return result + + +def safe_join_path(safe_root: str, directory: str) -> str: + """ + 安全地拼接路径,确保结果在安全根目录内 + + Args: + safe_root: 安全根目录 + directory: 要拼接的目录 + + Returns: + 规范化的完整路径 + + Raises: + ValueError: 如果路径不在安全根目录内 + """ + directory = os.path.join(safe_root, directory) + # Normalize the directory path + normalized_directory = os.path.normpath(directory) + # Ensure the directory is within the safe root + if not normalized_directory.startswith(os.path.normpath(safe_root)): + raise ValueError(f"Access to directory '{directory}' is not allowed.") + return normalized_directory + + +def _longest_common_prefix(file_names: list) -> str: + """查找文件名列表的最长公共前缀""" + if not file_names: + return "" + + # 将第一个文件名作为初始前缀 + prefix = file_names[0] + + for file_name in file_names[1:]: + while not file_name.startswith(prefix): + # 如果当前文件名不以prefix开头,则缩短prefix + prefix = prefix[:-1] + if not prefix: + return "" + + return prefix + + +def 
remove_common_prefix(directory: str) -> None: + """ + 移除目录下文件名的公共前缀 + + Args: + directory: 目录路径 + """ + files = os.listdir(directory) + + # 获取所有文件的前缀 + common_prefix = _longest_common_prefix(files) + + log.info(f'Common prefix identified: "{common_prefix}"') + + pattern = re.compile(r"^[pP]?(\d+)\s+\d*(.+?)\.(.*$)") + for filename in files: + if filename == common_prefix: + continue + # 检查文件名是否以共同前缀开头 + if filename.startswith(common_prefix): + # 构造新的文件名 + new_filename = filename[len(common_prefix) :] + match = pattern.search(new_filename.strip()) + if match: + num = match.group(1) + name = match.group(2).replace(".", " ").strip() + suffix = match.group(3) + new_filename = f"{num}.{name}.{suffix}" + # 生成完整的文件路径 + old_file_path = os.path.join(directory, filename) + new_file_path = os.path.join(directory, new_filename) + + # 重命名文件 + os.rename(old_file_path, new_file_path) + log.debug(f'Renamed: "{filename}" to "{new_filename}"') + + +def not_in_dirs(filename: str, ignore_absolute_dirs: list) -> bool: + """ + 判断文件是否不在排除目录列表中 + + Args: + filename: 文件路径 + ignore_absolute_dirs: 要忽略的绝对路径列表 + + Returns: + True 如果文件不在排除目录中 + """ + file_absolute_path = os.path.abspath(filename) + file_dir = os.path.dirname(file_absolute_path) + for ignore_dir in ignore_absolute_dirs: + if file_dir.startswith(ignore_dir): + log.info(f"{file_dir} in {ignore_dir}") + return False # 文件在排除目录中 + + return True # 文件不在排除目录中 + + +def chmodfile(file_path: str) -> None: + """修改文件权限为 775""" + try: + os.chmod(file_path, 0o775) + except Exception as e: + log.info(f"chmodfile failed: {e}") + + +def chmoddir(dir_path: str) -> None: + """修改目录下所有文件的权限为 775""" + # 获取指定目录下的所有文件和子目录 + for item in os.listdir(dir_path): + item_path = os.path.join(dir_path, item) + # 确保是文件,且不是目录 + if os.path.isfile(item_path): + try: + os.chmod(item_path, 0o775) + log.info(f"Changed permissions of file: {item_path}") + except Exception as e: + log.info(f"chmoddir failed: {e}") diff --git a/xiaomusic/utils/music_utils.py 
b/xiaomusic/utils/music_utils.py new file mode 100644 index 0000000..4acc05d --- /dev/null +++ b/xiaomusic/utils/music_utils.py @@ -0,0 +1,695 @@ +#!/usr/bin/env python3 +"""音乐文件处理相关工具函数""" + +import asyncio +import base64 +import hashlib +import io +import json +import logging +import mimetypes +import os +import shutil +import subprocess +import tempfile +from dataclasses import asdict, dataclass +from urllib.parse import urlparse + +import aiohttp +import mutagen +from mutagen.asf import ASF +from mutagen.flac import FLAC +from mutagen.id3 import ( + APIC, + ID3, + TALB, + TCON, + TDRC, + TIT2, + TPE1, + USLT, + Encoding, + TextFrame, + TimeStampTextFrame, +) +from mutagen.mp3 import MP3 +from mutagen.mp4 import MP4 +from mutagen.oggvorbis import OggVorbis +from mutagen.wave import WAVE +from mutagen.wavpack import WavPack +from PIL import Image + +from xiaomusic.const import SUPPORT_MUSIC_TYPE + +log = logging.getLogger(__package__) + + +@dataclass +class Metadata: + """音乐元数据""" + + title: str = "" + artist: str = "" + album: str = "" + year: str = "" + genre: str = "" + picture: str = "" + lyrics: str = "" + + def __init__(self, info=None): + if info: + self.title = info.get("title", "") + self.artist = info.get("artist", "") + self.album = info.get("album", "") + self.year = info.get("year", "") + self.genre = info.get("genre", "") + self.picture = info.get("picture", "") + self.lyrics = info.get("lyrics", "") + + +def is_mp3(url: str) -> bool: + """判断是否为 MP3 文件""" + mt = mimetypes.guess_type(url) + if mt and mt[0] == "audio/mpeg": + return True + return False + + +def is_m4a(url: str) -> bool: + """判断是否为 M4A 文件""" + return url.endswith(".m4a") + + +async def _get_web_music_duration( + session, url: str, config, start: int = 0, end: int = 500 +) -> float: + """ + 异步获取网络音乐文件的部分内容并估算其时长 + + 通过请求 URL 的前几个字节(默认 0-500)下载部分文件, + 写入临时文件后调用本地工具(如 ffprobe)获取音频时长 + + Args: + session: aiohttp.ClientSession 实例 + url: 音乐文件的 URL 地址 + config: 包含配置信息的对象(如 ffmpeg 路径) + start: 
请求的起始字节位置 + end: 请求的结束字节位置 + + Returns: + 返回音频的持续时间(秒),如果失败则返回 0 + """ + duration = 0 + # 设置请求头 Range,只请求部分内容(用于快速获取元数据) + headers = {"Range": f"bytes={start}-{end}"} + + # 使用 aiohttp 异步发起 GET 请求,获取部分音频内容 + async with session.get(url, headers=headers) as response: + array_buffer = await response.read() # 读取响应的二进制内容 + + # 创建一个命名的临时文件,并禁用自动删除(以便后续读取) + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp.write(array_buffer) # 将下载的部分内容写入临时文件 + tmp_path = tmp.name # 获取该临时文件的真实路径 + + try: + # 调用 get_local_music_duration 并传入文件路径,而不是文件对象 + duration = await get_local_music_duration(tmp_path, config) + except Exception as e: + log.error(f"Error _get_web_music_duration: {e}") + finally: + # 手动删除临时文件,避免残留 + os.unlink(tmp_path) + + return duration + + +async def get_web_music_duration(url: str, config) -> tuple[float, str]: + """ + 获取网络音乐时长 + + Args: + url: 音乐 URL + config: 配置对象 + + Returns: + (时长(秒), 最终URL) + """ + duration = 0 + try: + parsed_url = urlparse(url) + file_path = parsed_url.path + _, extension = os.path.splitext(file_path) + if extension.lower() not in SUPPORT_MUSIC_TYPE: + cleaned_url = parsed_url.geturl() + async with aiohttp.ClientSession() as session: + async with session.get( + cleaned_url, + allow_redirects=True, + headers={ + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36" + }, + ) as response: + url = str(response.url) + # 设置总超时时间为3秒 + timeout = aiohttp.ClientTimeout(total=3) + async with aiohttp.ClientSession(timeout=timeout) as session: + duration = await _get_web_music_duration( + session, url, config, start=0, end=500 + ) + if duration <= 0: + duration = await _get_web_music_duration( + session, url, config, start=0, end=3000 + ) + except Exception as e: + log.error(f"Error get_web_music_duration: {e}") + return duration, url + + +async def get_local_music_duration(filename: str, config) -> float: + """ + 获取本地音乐文件播放时长 + + Args: + filename: 文件路径 + config: 
配置对象 + + Returns: + 时长(秒) + """ + duration = 0 + if config.get_duration_type == "ffprobe": + duration = get_duration_by_ffprobe(filename, config.ffmpeg_location) + else: + duration = await get_duration_by_mutagen(filename) + + # 换个方式重试一次 + if duration == 0: + if config.get_duration_type != "ffprobe": + duration = get_duration_by_ffprobe(filename, config.ffmpeg_location) + else: + duration = await get_duration_by_mutagen(filename) + + return duration + + +async def get_duration_by_mutagen(file_path: str) -> float: + """使用 mutagen 获取音乐时长""" + duration = 0 + try: + loop = asyncio.get_event_loop() + if is_mp3(file_path): + m = await loop.run_in_executor(None, mutagen.mp3.MP3, file_path) + else: + m = await loop.run_in_executor(None, mutagen.File, file_path) + duration = m.info.length + except Exception as e: + log.warning(f"Error getting local music {file_path} duration: {e}") + return duration + + +def get_duration_by_ffprobe(file_path: str, ffmpeg_location: str) -> float: + """使用 ffprobe 获取音乐时长""" + duration = 0 + try: + # 构造 ffprobe 命令参数 + cmd_args = [ + os.path.join(ffmpeg_location, "ffprobe"), + "-v", + "error", # 只输出错误信息,避免混杂在其他输出中 + "-show_entries", + "format=duration", # 仅显示时长 + "-of", + "json", # 以 JSON 格式输出 + file_path, + ] + + # 输出待执行的完整命令 + full_command = " ".join(cmd_args) + log.info(f"待执行的完整命令 ffprobe command: {full_command}") + + # 使用 ffprobe 获取文件的元数据,并以 JSON 格式输出 + result = subprocess.run( + cmd_args, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + # 输出命令执行结果 + log.info( + f"命令执行结果 command result - return code: {result.returncode}, stdout: {result.stdout}" + ) + + # 解析 JSON 输出 + ffprobe_output = json.loads(result.stdout) + + # 获取时长 + duration = float(ffprobe_output["format"]["duration"]) + log.info( + f"Successfully extracted duration: {duration} seconds for file: {file_path}" + ) + + except Exception as e: + log.warning(f"Error getting local music {file_path} duration: {e}") + return duration + + +def no_padding(info) -> int: 
+ """移除 MP3 文件的 padding""" + # this will remove all padding + return 0 + + +def remove_id3_tags(input_file: str, config) -> str: + """ + 移除 MP3 文件的 ID3 标签以减少延迟 + + Args: + input_file: 输入文件路径 + config: 配置对象 + + Returns: + 处理后的相对路径,如果无需处理则返回 None + """ + audio = MP3(input_file, ID3=ID3) + + # 检查是否存在ID3 v2.3或v2.4标签 + if not ( + audio.tags + and (audio.tags.version == (2, 3, 0) or audio.tags.version == (2, 4, 0)) + ): + return None + + music_path = config.music_path + temp_dir = config.temp_dir + + # 构造新文件的路径 + out_file_name = os.path.splitext(os.path.basename(input_file))[0] + out_file_path = os.path.join(temp_dir, f"{out_file_name}.mp3") + relative_path = os.path.relpath(out_file_path, music_path) + + # 路径相同的情况 + input_absolute_path = os.path.abspath(input_file) + output_absolute_path = os.path.abspath(out_file_path) + if input_absolute_path == output_absolute_path: + log.info(f"File {input_file} = {out_file_path} . Skipping remove_id3_tags.") + return None + + # 检查目标文件是否存在 + if os.path.exists(out_file_path): + log.info(f"File {out_file_path} already exists. 
Skipping remove_id3_tags.") + return relative_path + + # 开始去除(不再需要检查) + # 拷贝文件 + shutil.copy(input_file, out_file_path) + outaudio = MP3(out_file_path, ID3=ID3) + # 删除ID3标签 + outaudio.delete() + # 保存修改后的文件 + outaudio.save(padding=no_padding) + log.info(f"File {out_file_path} remove_id3_tags ok.") + return relative_path + + +def convert_file_to_mp3(input_file: str, config) -> str: + """ + 转换音频文件为 MP3 格式 + + Args: + input_file: 输入文件路径 + config: 配置对象 + + Returns: + 转换后的相对路径,如果无需转换则返回 None + """ + music_path = config.music_path + temp_dir = config.temp_dir + + out_file_name = os.path.splitext(os.path.basename(input_file))[0] + out_file_path = os.path.join(temp_dir, f"{out_file_name}.mp3") + relative_path = os.path.relpath(out_file_path, music_path) + + # 路径相同的情况 + input_absolute_path = os.path.abspath(input_file) + output_absolute_path = os.path.abspath(out_file_path) + if input_absolute_path == output_absolute_path: + log.info(f"File {input_file} = {out_file_path} . Skipping convert_file_to_mp3.") + return None + + absolute_music_path = os.path.abspath(music_path) + if not input_absolute_path.startswith(absolute_music_path): + log.error(f"Invalid input file path: {input_file}") + return None + + # 检查目标文件是否存在 + if os.path.exists(out_file_path): + log.info(f"File {out_file_path} already exists. 
Skipping convert_file_to_mp3.") + return relative_path + + # 检查是否存在 loudnorm 参数 + loudnorm_args = [] + if config.loudnorm: + loudnorm_args = ["-af", config.loudnorm] + + command = [ + os.path.join(config.ffmpeg_location, "ffmpeg"), + "-i", + input_absolute_path, + "-f", + "mp3", + "-vn", + "-y", + *loudnorm_args, + out_file_path, + ] + + try: + subprocess.run(command, check=True) + except subprocess.CalledProcessError as e: + log.exception(f"Error during conversion: {e}") + return None + + log.info(f"File {input_file} to {out_file_path} convert_file_to_mp3 ok.") + return relative_path + + +def _to_utf8(v): + """转换标签值为 UTF-8 字符串""" + if isinstance(v, TextFrame) and not isinstance(v, TimeStampTextFrame): + old_ts = "".join(v.text) + if v.encoding == Encoding.LATIN1: + bs = old_ts.encode("latin1") + ts = bs.decode("GBK", errors="ignore") + return ts + return old_ts + elif isinstance(v, list): + return "".join(str(item) for item in v) + return str(v) + + +def _get_tag_value(tags, k: str) -> str: + """获取标签值""" + if k not in tags: + return "" + v = tags[k] + return _to_utf8(v) + + +def _get_alltag_value(tags, k: str) -> str: + """获取所有标签值""" + v = tags.getall(k) + if len(v) > 0: + return _to_utf8(v[0]) + return "" + + +def _save_picture(picture_data: bytes, save_root: str, file_path: str) -> str: + """保存图片""" + # 计算文件名的哈希值 + file_hash = hashlib.md5(file_path.encode("utf-8")).hexdigest() + # 创建目录结构 + dir_path = os.path.join(save_root, file_hash[-6:]) + os.makedirs(dir_path, exist_ok=True) + + # 保存图片 + filename = os.path.basename(file_path) + (name, _) = os.path.splitext(filename) + picture_path = os.path.join(dir_path, f"{name}.jpg") + + try: + _resize_save_image(picture_data, picture_path) + except Exception as e: + log.warning(f"Error _resize_save_image: {e}") + return picture_path + + +def _resize_save_image(image_bytes: bytes, save_path: str, max_size: int = 300) -> str: + """缩放并保存图片""" + # 将 bytes 转换为 PIL Image 对象 + image = None + try: + image = 
Image.open(io.BytesIO(image_bytes)) + image = image.convert("RGB") + except Exception as e: + log.warning(f"Error _resize_save_image: {e}") + return None + + # 获取原始尺寸 + original_width, original_height = image.size + + # 如果图片的宽度和高度都小于 max_size,则直接保存原始图片 + if original_width <= max_size and original_height <= max_size: + image.save(save_path, format="JPEG") + return save_path + + # 计算缩放比例,保持等比缩放 + scaling_factor = min(max_size / original_width, max_size / original_height) + + # 计算新的尺寸 + new_width = int(original_width * scaling_factor) + new_height = int(original_height * scaling_factor) + + resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + resized_image.save(save_path, format="JPEG") + return save_path + + +def save_picture_by_base64( + picture_base64_data: str, save_root: str, file_path: str +) -> str: + """通过 base64 数据保存图片""" + try: + picture_data = base64.b64decode(picture_base64_data) + except (TypeError, ValueError) as e: + log.exception(f"Error decoding base64 data: {e}") + return None + return _save_picture(picture_data, save_root, file_path) + + +def extract_audio_metadata(file_path: str, save_root: str) -> dict: + """ + 提取音频文件的元数据 + + Args: + file_path: 音频文件路径 + save_root: 图片保存根目录 + + Returns: + 元数据字典 + """ + metadata = Metadata() + + audio = None + try: + audio = mutagen.File(file_path) + except Exception as e: + log.warning(f"Error extract_audio_metadata file: {file_path} {e}") + + if audio is None: + return asdict(metadata) + + tags = audio.tags + if tags is None: + return asdict(metadata) + + if isinstance(audio, MP3): + metadata.title = _get_tag_value(tags, "TIT2") + metadata.artist = _get_tag_value(tags, "TPE1") + metadata.album = _get_tag_value(tags, "TALB") + metadata.year = _get_tag_value(tags, "TDRC") + metadata.genre = _get_tag_value(tags, "TCON") + metadata.lyrics = _get_alltag_value(tags, "USLT") + for tag in tags.values(): + if isinstance(tag, APIC): + metadata.picture = _save_picture(tag.data, save_root, 
file_path) + break + + elif isinstance(audio, FLAC): + metadata.title = _get_tag_value(tags, "TITLE") + metadata.artist = _get_tag_value(tags, "ARTIST") + metadata.album = _get_tag_value(tags, "ALBUM") + metadata.year = _get_tag_value(tags, "DATE") + metadata.genre = _get_tag_value(tags, "GENRE") + if audio.pictures: + metadata.picture = _save_picture( + audio.pictures[0].data, save_root, file_path + ) + if "lyrics" in audio: + metadata.lyrics = audio["lyrics"][0] + + elif isinstance(audio, MP4): + metadata.title = _get_tag_value(tags, "\xa9nam") + metadata.artist = _get_tag_value(tags, "\xa9ART") + metadata.album = _get_tag_value(tags, "\xa9alb") + metadata.year = _get_tag_value(tags, "\xa9day") + metadata.genre = _get_tag_value(tags, "\xa9gen") + if "covr" in tags: + metadata.picture = _save_picture(tags["covr"][0], save_root, file_path) + + elif isinstance(audio, OggVorbis): + metadata.title = _get_tag_value(tags, "TITLE") + metadata.artist = _get_tag_value(tags, "ARTIST") + metadata.album = _get_tag_value(tags, "ALBUM") + metadata.year = _get_tag_value(tags, "DATE") + metadata.genre = _get_tag_value(tags, "GENRE") + if "metadata_block_picture" in tags: + picture = json.loads(base64.b64decode(tags["metadata_block_picture"][0])) + metadata.picture = _save_picture( + base64.b64decode(picture["data"]), save_root, file_path + ) + + elif isinstance(audio, ASF): + metadata.title = _get_tag_value(tags, "Title") + metadata.artist = _get_tag_value(tags, "Author") + metadata.album = _get_tag_value(tags, "WM/AlbumTitle") + metadata.year = _get_tag_value(tags, "WM/Year") + metadata.genre = _get_tag_value(tags, "WM/Genre") + if "WM/Picture" in tags: + metadata.picture = _save_picture( + tags["WM/Picture"][0].value, save_root, file_path + ) + + elif isinstance(audio, WavPack): + metadata.title = _get_tag_value(tags, "Title") + metadata.artist = _get_tag_value(tags, "Artist") + metadata.album = _get_tag_value(tags, "Album") + metadata.year = _get_tag_value(tags, "Year") + 
metadata.genre = _get_tag_value(tags, "Genre") + if audio.pictures: + metadata.picture = _save_picture( + audio.pictures[0].data, save_root, file_path + ) + + elif isinstance(audio, WAVE): + metadata.title = _get_tag_value(tags, "Title") + metadata.artist = _get_tag_value(tags, "Artist") + + return asdict(metadata) + + +def _set_mp3_tags(audio, info: Metadata) -> None: + """设置 MP3 标签""" + audio.tags = ID3() + audio["TIT2"] = TIT2(encoding=3, text=info.title) + audio["TPE1"] = TPE1(encoding=3, text=info.artist) + audio["TALB"] = TALB(encoding=3, text=info.album) + audio["TDRC"] = TDRC(encoding=3, text=info.year) + audio["TCON"] = TCON(encoding=3, text=info.genre) + + # 使用 USLT 存储歌词 + if info.lyrics: + audio["USLT"] = USLT(encoding=3, lang="eng", text=info.lyrics) + + # 添加封面图片 + if info.picture: + with open(info.picture, "rb") as img_file: + image_data = img_file.read() + audio["APIC"] = APIC( + encoding=3, mime="image/jpeg", type=3, desc="Cover", data=image_data + ) + audio.save() # 保存修改 + + +def _set_flac_tags(audio, info: Metadata) -> None: + """设置 FLAC 标签""" + audio["TITLE"] = info.title + audio["ARTIST"] = info.artist + audio["ALBUM"] = info.album + audio["DATE"] = info.year + audio["GENRE"] = info.genre + if info.lyrics: + audio["LYRICS"] = info.lyrics + if info.picture: + with open(info.picture, "rb") as img_file: + image_data = img_file.read() + audio.add_picture(image_data) + + +def _set_mp4_tags(audio, info: Metadata) -> None: + """设置 MP4 标签""" + audio["nam"] = info.title + audio["ART"] = info.artist + audio["alb"] = info.album + audio["day"] = info.year + audio["gen"] = info.genre + if info.picture: + with open(info.picture, "rb") as img_file: + image_data = img_file.read() + audio["covr"] = [image_data] + + +def _set_ogg_tags(audio, info: Metadata) -> None: + """设置 OGG 标签""" + audio["TITLE"] = info.title + audio["ARTIST"] = info.artist + audio["ALBUM"] = info.album + audio["DATE"] = info.year + audio["GENRE"] = info.genre + if info.lyrics: + 
audio["LYRICS"] = info.lyrics + if info.picture: + with open(info.picture, "rb") as img_file: + image_data = img_file.read() + audio["metadata_block_picture"] = base64.b64encode(image_data).decode() + + +def _set_asf_tags(audio, info: Metadata) -> None: + """设置 ASF 标签""" + audio["Title"] = info.title + audio["Author"] = info.artist + audio["WM/AlbumTitle"] = info.album + audio["WM/Year"] = info.year + audio["WM/Genre"] = info.genre + if info.picture: + with open(info.picture, "rb") as img_file: + image_data = img_file.read() + audio["WM/Picture"] = image_data + + +def _set_wave_tags(audio, info: Metadata) -> None: + """设置 WAVE 标签""" + audio["Title"] = info.title + audio["Artist"] = info.artist + + +def set_music_tag_to_file(file_path: str, info: Metadata) -> str: + """ + 设置音乐文件的标签信息 + + Args: + file_path: 文件路径 + info: 元数据对象 + + Returns: + "OK" 或错误信息 + """ + audio = mutagen.File(file_path, easy=True) + if audio is None: + log.error(f"Unable to open file {file_path}") + return "Unable to open file" + + if isinstance(audio, MP3): + _set_mp3_tags(audio, info) + elif isinstance(audio, FLAC): + _set_flac_tags(audio, info) + elif isinstance(audio, MP4): + _set_mp4_tags(audio, info) + elif isinstance(audio, OggVorbis): + _set_ogg_tags(audio, info) + elif isinstance(audio, ASF): + _set_asf_tags(audio, info) + elif isinstance(audio, WAVE): + _set_wave_tags(audio, info) + else: + log.error(f"Unsupported file type for {file_path}") + return "Unsupported file type" + + try: + audio.save() + log.info(f"Tags saved successfully to {file_path}") + return "OK" + except Exception as e: + log.exception(f"Error saving tags: {e}") + return "Error saving tags" diff --git a/xiaomusic/utils/network_utils.py b/xiaomusic/utils/network_utils.py new file mode 100644 index 0000000..b5cb47c --- /dev/null +++ b/xiaomusic/utils/network_utils.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +"""网络请求和下载相关工具函数""" + +import asyncio +import hashlib +import logging +import os +import time +from collections 
async def check_bili_fav_list(url: str) -> dict:
    """
    List the videos of a Bilibili favourites folder or collection.

    Supported URLs live on ``space.bilibili.com``:

    * ``.../favlist?fid=<id>&ctype=11`` - a user-created favourites folder
    * ``.../favlist?fid=<id>&ctype=21`` - a collected season
    * ``.../lists/<id>?type=season`` - a season / collection

    Args:
        url: Favourites-folder or collection URL.

    Returns:
        A dict mapping each video to its title. Season/collection entries are
        keyed by the bare ``bvid``; favourites-folder entries are keyed by the
        full ``https://www.bilibili.com/video/<bvid>`` URL.

    Raises:
        ValueError: If the URL is not one of the supported list kinds or no
            list id can be extracted from it.
        Exception: If the Bilibili API answers with a non-200 status.
    """
    parsed_url = urlparse(url)
    path = parsed_url.path
    query_params = parse_qs(parsed_url.query)

    # Both start as None so that an unsupported host/path ends in the
    # ValueError below instead of an UnboundLocalError further down.
    lid = None
    list_type = None
    if parsed_url.hostname == "space.bilibili.com":
        if "/favlist" in path:
            lid = query_params.get("fid", [None])[0]
            ctype = query_params.get("ctype", [None])[0]
            if ctype == "11":
                list_type = "create"
            elif ctype == "21":
                list_type = "collect"
            else:
                raise ValueError("当前只支持合集和收藏夹")
        elif "/lists/" in path:
            parts = path.split("/")
            if len(parts) >= 4 and "?" in url:
                lid = parts[3]
                list_type = query_params.get("type", [None])[0]

    # Validate before opening any connection.
    if not lid or list_type not in ("season", "collect", "create"):
        raise ValueError("当前只支持合集和收藏夹")

    is_season = list_type in ("season", "collect")
    items_key = "archives" if is_season else "medias"

    # https://api.bilibili.com/x/polymer/web-space/seasons_archives_list?season_id={lid}&page_size=30&page_num=1
    page_size = 100
    page_num = 1
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0 Safari/537.36",
        "Accept": "application/json, text/plain, */*",
        "Referer": url,
        "Origin": "https://space.bilibili.com",
    }

    bvid_info = {}
    async with aiohttp.ClientSession(headers=headers) as session:
        while True:
            if is_season:
                list_url = f"https://api.bilibili.com/x/polymer/web-space/seasons_archives_list?season_id={lid}&page_size={page_size}&page_num={page_num}"
            else:
                list_url = f"https://api.bilibili.com/x/v3/fav/resource/list?media_id={lid}&pn={page_num}&ps={page_size}&order=mtime"

            async with session.get(list_url) as response:
                if response.status != 200:
                    raise Exception(f"Failed to fetch data from {list_url}")
                data = await response.json()

            # `or` guards cover payloads where "data" or the item list is null.
            items = (data.get("data") or {}).get(items_key) or []
            if not items:
                break

            for item in items:
                bvid = item.get("bvid", None)
                title = item.get("title", None)
                if is_season:
                    bvid_info[bvid] = title
                else:
                    bvid_info[f"https://www.bilibili.com/video/{bvid}"] = title

            # A short page means this was the last one.
            if len(items) < page_size:
                break
            page_num += 1
            if is_season:
                # Throttle paging of the season endpoint; asyncio.sleep yields
                # to the event loop, whereas time.sleep would block it.
                await asyncio.sleep(1)
    return bvid_info
async def download_one_music(config, url: str, name: str = ""):
    """
    Start a ``yt-dlp`` process that downloads a single track as MP3.

    Args:
        config: Configuration object; ``download_path``, ``ffmpeg_location``,
            ``proxy``, ``enable_yt_dlp_cookies``, ``yt_dlp_cookies_path`` and
            ``loudnorm`` are read from it.
        url: URL of the track to download.
        name: Optional output file name (without extension). When empty the
            title reported by yt-dlp is used.

    Returns:
        The process object returned by ``asyncio.create_subprocess_exec``;
        this function does not wait for the download to finish.
    """
    output_template = f"{name}.%(ext)s" if name else "%(title)s.%(ext)s"

    argv = [
        "yt-dlp",
        "--no-playlist",
        "-x",
        "--audio-format",
        "mp3",
        "--audio-quality",
        "0",
        "--paths",
        config.download_path,
        "-o",
        output_template,
        "--ffmpeg-location",
        f"{config.ffmpeg_location}",
    ]

    # Optional switches, appended in a fixed order.
    if config.proxy:
        argv.extend(["--proxy", f"{config.proxy}"])
    if config.enable_yt_dlp_cookies:
        argv.extend(["--cookies", f"{config.yt_dlp_cookies_path}"])
    if config.loudnorm:
        argv.extend(["--postprocessor-args", f"-af {config.loudnorm}"])

    # The URL is always the final positional argument.
    argv.append(url)

    cmd = " ".join(argv)
    log.info(f"download_one_music: {cmd}")
    return await asyncio.create_subprocess_exec(*argv)
config.proxy + try: + # 2. 传入代理配置创建ClientSession + async with aiohttp.ClientSession(connector=connector) as session: + # 3. 发起带代理的GET请求 + async with session.get( + url, + headers=headers, + proxy=proxy, # 传入格式化后的代理参数 + timeout=10, # 超时时间(秒),避免无限等待 + ) as response: + if response.status == 200: + data = await response.json() + log.info(f"fetch_json_get: {url} success {data}") + + # 确保返回结果为dict + if isinstance(data, dict): + return data + else: + log.warning(f"Expected dict, but got {type(data)}: {data}") + return {} + else: + log.error(f"HTTP Error: {response.status} {url}") + return {} + except aiohttp.ClientError as e: + log.error(f"ClientError fetching {url} (proxy: {proxy}): {e}") + return {} + except asyncio.TimeoutError: + log.error(f"Timeout fetching {url} (proxy: {proxy})") + return {} + except Exception as e: + log.error(f"Unexpected error fetching {url} (proxy: {proxy}): {e}") + return {} + finally: + # 4. 关闭连接器(避免资源泄漏) + if connector and not connector.closed: + await connector.close() + + +class LRUCache(OrderedDict): + """LRU 缓存实现""" + + def __init__(self, max_size: int = 1000): + super().__init__() + self.max_size = max_size + + def __setitem__(self, key, value): + if key in self: + # 移动到末尾(最近使用) + self.move_to_end(key) + super().__setitem__(key, value) + # 如果超出大小限制,删除最早使用的项 + if len(self) > self.max_size: + self.popitem(last=False) + + def __getitem__(self, key): + # 访问时移动到末尾(最近使用) + if key in self: + self.move_to_end(key) + return super().__getitem__(key) + + +class MusicUrlCache: + """音乐 URL 缓存管理器""" + + def __init__(self, default_expire_days: int = 1, max_size: int = 1000): + self.cache = LRUCache(max_size) + self.default_expire_days = default_expire_days + self.log = logging.getLogger(__name__) + + async def get(self, url: str, headers: dict = None, config=None) -> str: + """ + 获取URL(优先从缓存获取,没有则请求API) + + Args: + url: 原始URL + headers: API请求需要的headers + config: 配置对象 + + Returns: + str: 真实播放URL + """ + # 先查询缓存 + cached_url = self._get_from_cache(url) 
async def text_to_mp3(
    text: str, save_dir: str, voice: str = "zh-CN-XiaoxiaoNeural"
) -> str:
    """
    Convert text to an MP3 speech file with edge-tts, caching by content.

    The file name is the MD5 of ``"<text>_<voice>"``, so the same text with
    the same voice is synthesised only once and served from ``save_dir``
    afterwards.

    Args:
        text: Text to synthesise.
        save_dir: Directory that holds the generated MP3 files (created if
            missing).
        voice: edge-tts voice name.

    Returns:
        Full path of the MP3 file.

    Raises:
        RuntimeError: If synthesis fails; the original error is chained.
    """
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # The digest is only a cache key, not a security measure.
    content = f"{text}_{voice}".encode()
    file_hash = hashlib.md5(content, usedforsecurity=False).hexdigest()
    mp3_path = os.path.join(save_dir, f"{file_hash}.mp3")

    # Cache hit: an earlier call already produced this file.
    if os.path.exists(mp3_path):
        return mp3_path

    # Synthesise into a temporary sibling and move it into place only once it
    # is complete. Otherwise a failed or interrupted save would leave a
    # truncated file at mp3_path that every later call returns as a cache hit.
    tmp_path = f"{mp3_path}.{time.monotonic_ns()}.part"
    try:
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(tmp_path)
        os.replace(tmp_path, mp3_path)
    except Exception as e:
        # Remove whatever was written so no partial file lingers.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise RuntimeError(f"生成语音文件失败: {e}") from e
    else:
        log.info(f"语音文件生成成功: {mp3_path}")

    return mp3_path
def deepcopy_data_no_sensitive_info(data, fields_to_anonymize: list = None):
    """
    Return a deep copy of ``data`` with sensitive top-level fields masked.

    Args:
        data: A dict or an arbitrary object; the input itself is not modified.
        fields_to_anonymize: Names of the keys/attributes to mask. Defaults to
            the account, password and HTTP-auth credential fields.

    Returns:
        The deep copy, with every listed field that exists on it replaced by
        ``"******"``. Only top-level keys/attributes are inspected.
    """
    sensitive_names = (
        [
            "account",
            "password",
            "httpauth_username",
            "httpauth_password",
        ]
        if fields_to_anonymize is None
        else fields_to_anonymize
    )

    masked = copy.deepcopy(data)
    is_mapping = isinstance(masked, dict)

    for name in sensitive_names:
        if is_mapping:
            # Dicts are masked by key.
            if name in masked:
                masked[name] = "******"
        elif hasattr(masked, name):
            # Everything else is masked by attribute.
            setattr(masked, name, "******")

    return masked
async def download_and_extract(url: str, target_directory: str) -> str:
    """
    Download a file into ``target_directory`` and unpack it if it is a tarball.

    Args:
        url: Download URL. The file name is the last segment of the URL path.
        target_directory: Directory that receives the download (created if
            missing) and into which a ``.tar.gz`` archive is extracted.

    Returns:
        "OK" when the download succeeded and extraction was handed to
        ``extract_tar_gz``; otherwise a message describing the failure
        ("Invalid file path", a bad-package message or the HTTP status).
    """
    ret = "OK"
    os.makedirs(target_directory, exist_ok=True)

    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status == 200:
                # Take the name from the URL *path* only, so a query string or
                # fragment cannot leak into the file name and defeat the
                # ".tar.gz" check below.
                base_name = os.path.basename(urlparse(url).path)
                file_name = os.path.normpath(
                    os.path.join(target_directory, base_name)
                )
                # Refuse names that resolve to a directory ("", ".", "..") or
                # to anything not directly inside the target directory. Both
                # sides are made absolute so that spellings like "./app" or
                # "/app/" compare equal.
                inside_target = os.path.dirname(
                    os.path.abspath(file_name)
                ) == os.path.abspath(target_directory)
                if base_name in ("", ".", "..") or not inside_target:
                    log.warning(f"Invalid file path: {file_name}")
                    return "Invalid file path"

                with open(file_name, "wb") as f:
                    # Stream chunk by chunk to keep memory usage flat.
                    async for chunk in response.content.iter_any():
                        f.write(chunk)
                log.info(f"文件下载完成: {file_name}")

                if file_name.endswith(".tar.gz"):
                    await extract_tar_gz(file_name, target_directory)
                else:
                    ret = f"下载失败, 包有问题: {file_name}"
                    log.warning(ret)
            else:
                ret = f"下载失败, 状态码: {response.status}"
                log.warning(ret)
    return ret
def keyword_detection(user_input: str, str_list: list, n: int) -> tuple[list, list]:
    """
    Split candidates into those containing the keyword and the rest.

    Args:
        user_input: Keyword to look for (plain substring test).
        str_list: Candidate strings.
        n: Maximum number of matches to return; -1 returns all of them.

    Returns:
        ``(matched, remains)``. ``matched`` holds the candidates containing
        ``user_input``, most similar first. ``remains`` holds any matches cut
        off by ``n`` followed by the non-matching candidates in input order.
    """
    hits, misses = [], []
    for candidate in str_list:
        (hits if user_input in candidate else misses).append(candidate)

    def similarity(candidate):
        return difflib.SequenceMatcher(None, candidate, user_input).ratio()

    # Most similar first; the sort is stable, so ties keep their input order.
    hits.sort(key=similarity, reverse=True)

    # -1 means "no limit"; a limit above the hit count changes nothing either.
    if n == -1 or n > len(hits):
        return hits, misses

    # Matches beyond the limit are handed back ahead of the non-matches.
    return hits[:n], hits[n:] + misses
def chinese_to_number(chinese: str) -> int:
    """
    Convert a Chinese numeral string to an integer.

    Characters are looked up in the module-level ``chinese_to_arabic`` table;
    values below 10 are treated as digits and values of 10 or more as unit
    characters. Characters missing from the table are ignored.

    Args:
        chinese: Chinese numeral string, e.g. "一百二十三".

    Returns:
        The corresponding integer (0 when no digit character is present).
    """
    # A leading "十" is shorthand for "一十" (e.g. "十二" == 12).
    if chinese.startswith("十"):
        chinese = "一" + chinese

    # A lone unit character stands for its own value. The membership test
    # keeps an unknown single character from raising KeyError; it falls
    # through to the loop and yields 0 like any other unknown character.
    if (
        len(chinese) == 1
        and chinese in chinese_to_arabic
        and chinese_to_arabic[chinese] >= 10
    ):
        return chinese_to_arabic[chinese]

    total = 0
    # Largest unit seen so far while scanning right-to-left (e.g. 万 once we
    # are inside the "…万" part of the number).
    section = 1
    # Multiplier applied to the next digit encountered.
    unit = 1
    for char in reversed(chinese):
        if char not in chinese_to_arabic:
            continue
        val = chinese_to_arabic[char]
        if val >= 10:
            if val > section:
                # A bigger unit opens a new section.
                section = val
                unit = val
            else:
                # A smaller unit scales the current section (百 inside 万 is
                # 百万). Multiplying the section rather than the running unit
                # keeps "三百二十万" from compounding 百 x 十 x 万.
                unit = section * val
        else:
            total += val * unit

    return total