1
0
mirror of https://github.com/hanxi/xiaomusic.git synced 2026-05-24 11:35:46 +08:00

refactor: 重构拆分 utils 文件

This commit is contained in:
涵曦
2026-01-05 20:23:42 +08:00
parent e4cf48e234
commit ddf2aef7b7
7 changed files with 2020 additions and 1502 deletions

File diff suppressed because it is too large Load Diff

125
xiaomusic/utils/__init__.py Normal file
View File

@@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
Utils package - 工具函数模块
将原 utils.py 拆分为多个职责清晰的子模块:
- text_utils: 文本处理和搜索
- file_utils: 文件和目录操作
- music_utils: 音乐文件处理
- network_utils: 网络请求和下载
- system_utils: 系统操作和环境
"""
# 从各子模块导入常用函数,保持向后兼容
from xiaomusic.utils.file_utils import (
chmoddir,
chmodfile,
not_in_dirs,
remove_common_prefix,
safe_join_path,
traverse_music_directory,
)
from xiaomusic.utils.music_utils import (
Metadata,
convert_file_to_mp3,
extract_audio_metadata,
get_duration_by_ffprobe,
get_duration_by_mutagen,
get_local_music_duration,
get_web_music_duration,
is_m4a,
is_mp3,
remove_id3_tags,
save_picture_by_base64,
set_music_tag_to_file,
)
from xiaomusic.utils.network_utils import (
MusicUrlCache,
check_bili_fav_list,
download_one_music,
download_playlist,
downloadfile,
fetch_json_get,
text_to_mp3,
)
from xiaomusic.utils.system_utils import (
deepcopy_data_no_sensitive_info,
download_and_extract,
get_latest_version,
get_os_architecture,
get_random,
is_docker,
parse_cookie_string,
restart_xiaomusic,
try_add_access_control_param,
update_version,
validate_proxy,
)
from xiaomusic.utils.text_utils import (
calculate_tts_elapse,
chinese_to_number,
custom_sort_key,
find_best_match,
find_key_by_partial_string,
fuzzyfinder,
keyword_detection,
list2str,
parse_str_to_dict,
split_sentences,
traditional_to_simple,
)
# Public API of the package: every name re-exported above, grouped by the
# submodule it is imported from, so existing imports from the package keep
# resolving after the split of the original utils module.
__all__ = [
    # text_utils
    "calculate_tts_elapse",
    "chinese_to_number",
    "custom_sort_key",
    "find_best_match",
    "find_key_by_partial_string",
    "fuzzyfinder",
    "keyword_detection",
    "list2str",
    "parse_str_to_dict",
    "split_sentences",
    "traditional_to_simple",
    # file_utils
    "chmoddir",
    "chmodfile",
    "not_in_dirs",
    "remove_common_prefix",
    "safe_join_path",
    "traverse_music_directory",
    # music_utils
    "Metadata",
    "convert_file_to_mp3",
    "extract_audio_metadata",
    "get_duration_by_ffprobe",
    "get_duration_by_mutagen",
    "get_local_music_duration",
    "get_web_music_duration",
    "is_m4a",
    "is_mp3",
    "remove_id3_tags",
    "set_music_tag_to_file",
    "save_picture_by_base64",
    # network_utils
    "MusicUrlCache",
    "check_bili_fav_list",
    "download_one_music",
    "download_playlist",
    "downloadfile",
    "fetch_json_get",
    "text_to_mp3",
    # system_utils
    "deepcopy_data_no_sensitive_info",
    "download_and_extract",
    "get_latest_version",
    "get_os_architecture",
    "get_random",
    "is_docker",
    "parse_cookie_string",
    "restart_xiaomusic",
    "try_add_access_control_param",
    "update_version",
    "validate_proxy",
]

View File

@@ -0,0 +1,189 @@
#!/usr/bin/env python3
"""文件和目录操作相关工具函数"""
import logging
import os
import re
log = logging.getLogger(__package__)
def _get_depth_path(root: str, directory: str, depth: int) -> str:
"""计算指定深度的路径"""
# 计算当前目录的深度
relative_path = root[len(directory) :].strip(os.sep)
path_parts = relative_path.split(os.sep)
if len(path_parts) >= depth:
return os.path.join(directory, *path_parts[:depth])
else:
return root
def _append_files_result(
result: dict, root: str, joinpath: str, files: list, support_extension: set
) -> None:
"""将文件添加到结果字典中"""
dir_name = os.path.basename(root)
if dir_name not in result:
result[dir_name] = []
for file in files:
# 过滤隐藏文件
if file.startswith("."):
continue
# 过滤文件后缀
(name, extension) = os.path.splitext(file)
if extension.lower() not in support_extension:
continue
result[dir_name].append(os.path.join(joinpath, file))
def traverse_music_directory(
    directory: str, depth: int, exclude_dirs: set, support_extension: set
) -> dict:
    """Walk a music directory and group supported files by directory name.

    Directories nested deeper than ``depth`` are folded into the bucket of
    their ancestor instead of getting a bucket of their own.

    Args:
        directory: Root directory to walk (symlinks are followed).
        depth: Maximum depth that still gets its own bucket.
        exclude_dirs: Directory names that are not descended into.
        support_extension: Accepted file extensions.

    Returns:
        ``{directory name: [file paths]}``.
    """
    result = {}
    base_len = len(directory)
    for root, dirs, files in os.walk(directory, followlinks=True):
        # Prune excluded directories in place so os.walk skips them.
        dirs[:] = [name for name in dirs if name not in exclude_dirs]
        level = root[base_len:].count(os.sep) + 1
        if level <= depth:
            bucket_root = root
        else:
            bucket_root = _get_depth_path(root, directory, depth - 1)
        _append_files_result(result, bucket_root, root, files, support_extension)
    return result
def safe_join_path(safe_root: str, directory: str) -> str:
    """Join ``directory`` onto ``safe_root`` and refuse to leave the root.

    Containment is checked on whole path components. A plain string-prefix
    check would accept siblings such as ``/music_private`` for the root
    ``/music``.

    Args:
        safe_root: Directory the result must stay inside.
        directory: Path to join onto ``safe_root``.

    Returns:
        The normalized joined path.

    Raises:
        ValueError: If the normalized path is outside ``safe_root``.
    """
    joined = os.path.join(safe_root, directory)
    normalized_root = os.path.normpath(safe_root)
    normalized_directory = os.path.normpath(joined)
    try:
        inside = (
            os.path.commonpath([normalized_root, normalized_directory])
            == normalized_root
        )
    except ValueError:
        # commonpath rejects absolute/relative mixes and different drives;
        # such a path cannot be inside the root.
        inside = False
    if not inside:
        raise ValueError(f"Access to directory '{joined}' is not allowed.")
    return normalized_directory
def _longest_common_prefix(file_names: list) -> str:
"""查找文件名列表的最长公共前缀"""
if not file_names:
return ""
# 将第一个文件名作为初始前缀
prefix = file_names[0]
for file_name in file_names[1:]:
while not file_name.startswith(prefix):
# 如果当前文件名不以prefix开头则缩短prefix
prefix = prefix[:-1]
if not prefix:
return ""
return prefix
def remove_common_prefix(directory: str) -> None:
    """
    Strip the prefix shared by all file names in a directory by renaming them.

    After the prefix is removed, a remainder of the form
    ``[p|P]<number> <optional digits><name>.<suffix>`` is rewritten as
    ``<number>.<name>.<suffix>`` with dots inside the name replaced by spaces.

    Args:
        directory: Directory whose entries are renamed in place.
    """
    files = os.listdir(directory)
    # Longest prefix shared by every entry name.
    common_prefix = _longest_common_prefix(files)
    log.info(f'Common prefix identified: "{common_prefix}"')
    pattern = re.compile(r"^[pP]?(\d+)\s+\d*(.+?)\.(.*$)")
    for filename in files:
        # An entry that is exactly the prefix would be renamed to an empty name.
        if filename == common_prefix:
            continue
        # Only touch entries that actually start with the shared prefix.
        if filename.startswith(common_prefix):
            # Build the new file name from what follows the prefix.
            new_filename = filename[len(common_prefix) :]
            match = pattern.search(new_filename.strip())
            if match:
                num = match.group(1)
                name = match.group(2).replace(".", " ").strip()
                suffix = match.group(3)
                new_filename = f"{num}.{name}.{suffix}"
            # Build the full source and destination paths.
            old_file_path = os.path.join(directory, filename)
            new_file_path = os.path.join(directory, new_filename)
            # Rename the file.
            os.rename(old_file_path, new_file_path)
            log.debug(f'Renamed: "{filename}" to "{new_filename}"')
def not_in_dirs(filename: str, ignore_absolute_dirs: list) -> bool:
    """
    Tell whether a file lies outside every ignored directory.

    The comparison is made on whole path components, so ``/music/ignore2``
    is not treated as being inside ``/music/ignore``.

    Args:
        filename: Path of the file to check.
        ignore_absolute_dirs: Absolute directory paths to ignore.

    Returns:
        True if the file is not inside any ignored directory, else False.
    """
    file_dir = os.path.dirname(os.path.abspath(filename))
    for ignore_dir in ignore_absolute_dirs:
        if not ignore_dir:
            # An empty entry would match every path; treat it as "no directory".
            continue
        normalized = os.path.normpath(ignore_dir)
        # rstrip keeps the filesystem root ("/") working as an ignore dir.
        child_prefix = normalized.rstrip(os.sep) + os.sep
        if file_dir == normalized or file_dir.startswith(child_prefix):
            log.info(f"{file_dir} in {ignore_dir}")
            return False  # file is inside an ignored directory
    return True  # file is not inside any ignored directory
def chmodfile(file_path: str) -> None:
    """Set the mode of ``file_path`` to 775; failures are logged, not raised."""
    mode = 0o775
    try:
        os.chmod(file_path, mode)
    except Exception as e:
        log.info(f"chmodfile failed: {e}")
def chmoddir(dir_path: str) -> None:
    """Set mode 775 on every regular file directly inside ``dir_path``.

    Sub-directories are skipped and not descended into. A failure on one
    file is logged and does not stop the others.
    """
    for entry in os.listdir(dir_path):
        target = os.path.join(dir_path, entry)
        # Only regular files are changed.
        if not os.path.isfile(target):
            continue
        try:
            os.chmod(target, 0o775)
            log.info(f"Changed permissions of file: {target}")
        except Exception as e:
            log.info(f"chmoddir failed: {e}")

View File

@@ -0,0 +1,695 @@
#!/usr/bin/env python3
"""音乐文件处理相关工具函数"""
import asyncio
import base64
import hashlib
import io
import json
import logging
import mimetypes
import os
import shutil
import subprocess
import tempfile
from dataclasses import asdict, dataclass
from urllib.parse import urlparse
import aiohttp
import mutagen
from mutagen.asf import ASF
from mutagen.flac import FLAC
from mutagen.id3 import (
APIC,
ID3,
TALB,
TCON,
TDRC,
TIT2,
TPE1,
USLT,
Encoding,
TextFrame,
TimeStampTextFrame,
)
from mutagen.mp3 import MP3
from mutagen.mp4 import MP4
from mutagen.oggvorbis import OggVorbis
from mutagen.wave import WAVE
from mutagen.wavpack import WavPack
from PIL import Image
from xiaomusic.const import SUPPORT_MUSIC_TYPE
log = logging.getLogger(__package__)
@dataclass
class Metadata:
    """Tag information of one music file.

    Every field defaults to an empty string. ``picture`` holds a path to a
    cover image file.
    """

    title: str = ""
    artist: str = ""
    album: str = ""
    year: str = ""
    genre: str = ""
    picture: str = ""
    lyrics: str = ""

    def __init__(self, info=None):
        """Populate the fields from ``info`` (a mapping); missing keys give ""."""
        if not info:
            # Nothing to copy: the class-level defaults stay in effect.
            return
        for key in ("title", "artist", "album", "year", "genre", "picture", "lyrics"):
            setattr(self, key, info.get(key, ""))
def is_mp3(url: str) -> bool:
    """Return True when the MIME type guessed from ``url`` is ``audio/mpeg``."""
    mime_type, _encoding = mimetypes.guess_type(url)
    return mime_type == "audio/mpeg"
def is_m4a(url: str) -> bool:
    """Return True when ``url`` ends with the literal suffix ``.m4a``."""
    m4a_suffix = ".m4a"
    return url.endswith(m4a_suffix)
async def _get_web_music_duration(
    session, url: str, config, start: int = 0, end: int = 500
) -> float:
    """
    Estimate the duration of a remote audio file from a byte range of it.

    The requested range is written to a temporary file, which is measured
    with ``get_local_music_duration`` and then deleted.

    Args:
        session: Client session used for the GET request.
        url: URL of the audio file.
        config: Configuration object passed on to ``get_local_music_duration``.
        start: First byte of the requested range.
        end: Last byte of the requested range.

    Returns:
        The duration returned by ``get_local_music_duration``, or 0 if it raised.
    """
    duration = 0
    # Ask only for a slice of the file.
    range_header = {"Range": f"bytes={start}-{end}"}
    async with session.get(url, headers=range_header) as response:
        payload = await response.read()
    # delete=False: the file must outlive this ``with`` so it can be reopened
    # by path.
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(payload)
        tmp_path = tmp.name
    try:
        duration = await get_local_music_duration(tmp_path, config)
    except Exception as e:
        log.error(f"Error _get_web_music_duration: {e}")
    finally:
        # Remove the temporary file ourselves.
        os.unlink(tmp_path)
    return duration
async def get_web_music_duration(url: str, config) -> tuple[float, str]:
    """
    Get the duration of a remote audio file.

    When the URL path has no supported music extension, the URL is first
    requested with redirects enabled and the final response URL is used.
    The duration is probed with bytes 0-500 and, if that gives nothing,
    again with bytes 0-3000.

    Args:
        url: Music URL.
        config: Configuration object passed on to the probe helper.

    Returns:
        ``(duration, url)``; duration is 0 on failure and ``url`` may be the
        redirected URL.
    """
    duration = 0
    try:
        parsed_url = urlparse(url)
        extension = os.path.splitext(parsed_url.path)[1]
        if extension.lower() not in SUPPORT_MUSIC_TYPE:
            browser_headers = {
                "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36"
            }
            async with aiohttp.ClientSession() as redirect_session:
                async with redirect_session.get(
                    parsed_url.geturl(),
                    allow_redirects=True,
                    headers=browser_headers,
                ) as response:
                    url = str(response.url)
        # Total timeout of 3 seconds for the probing session.
        probe_timeout = aiohttp.ClientTimeout(total=3)
        async with aiohttp.ClientSession(timeout=probe_timeout) as session:
            for range_end in (500, 3000):
                duration = await _get_web_music_duration(
                    session, url, config, start=0, end=range_end
                )
                if duration > 0:
                    break
    except Exception as e:
        log.error(f"Error get_web_music_duration: {e}")
    return duration, url
async def get_local_music_duration(filename: str, config) -> float:
    """
    Get the playing time of a local music file.

    The method named by ``config.get_duration_type`` is tried first
    (``"ffprobe"`` selects ffprobe, anything else mutagen). If it yields 0,
    the other method is tried once.

    Args:
        filename: Path of the file.
        config: Configuration object (``get_duration_type``, ``ffmpeg_location``).

    Returns:
        The duration, or 0 if both methods failed.
    """
    prefer_ffprobe = config.get_duration_type == "ffprobe"
    if prefer_ffprobe:
        duration = get_duration_by_ffprobe(filename, config.ffmpeg_location)
    else:
        duration = await get_duration_by_mutagen(filename)
    if duration != 0:
        return duration
    # Retry once with the other method.
    if prefer_ffprobe:
        duration = await get_duration_by_mutagen(filename)
    else:
        duration = get_duration_by_ffprobe(filename, config.ffmpeg_location)
    return duration
async def get_duration_by_mutagen(file_path: str) -> float:
    """Read the duration of ``file_path`` with mutagen in the default executor.

    Returns 0 and logs a warning if the file cannot be opened or parsed.
    """
    duration = 0
    try:
        loop = asyncio.get_event_loop()
        # MP3 files are opened with the MP3 class directly; everything else
        # goes through mutagen's generic File loader.
        opener = mutagen.mp3.MP3 if is_mp3(file_path) else mutagen.File
        audio = await loop.run_in_executor(None, opener, file_path)
        duration = audio.info.length
    except Exception as e:
        log.warning(f"Error getting local music {file_path} duration: {e}")
    return duration
def get_duration_by_ffprobe(file_path: str, ffmpeg_location: str) -> float:
    """Read the duration of ``file_path`` by running ffprobe.

    stderr is captured separately from stdout. If the two streams were
    merged, any diagnostic printed by ffprobe would corrupt the JSON
    document and discard an otherwise valid duration.

    Args:
        file_path: Path of the audio file.
        ffmpeg_location: Directory that contains the ``ffprobe`` executable.

    Returns:
        The duration as a float, or 0 if ffprobe failed or printed no duration.
    """
    duration = 0
    try:
        cmd_args = [
            os.path.join(ffmpeg_location, "ffprobe"),
            "-v",
            "error",  # print errors only
            "-show_entries",
            "format=duration",  # only the duration entry
            "-of",
            "json",  # JSON output
            file_path,
        ]
        full_command = " ".join(cmd_args)
        log.info(f"待执行的完整命令 ffprobe command: {full_command}")
        result = subprocess.run(
            cmd_args,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        log.info(
            f"命令执行结果 command result - return code: {result.returncode}, stdout: {result.stdout}"
        )
        if result.stderr:
            # Diagnostics are still logged, just kept out of the JSON.
            log.warning(f"ffprobe stderr for {file_path}: {result.stderr}")
        ffprobe_output = json.loads(result.stdout)
        duration = float(ffprobe_output["format"]["duration"])
        log.info(
            f"Successfully extracted duration: {duration} seconds for file: {file_path}"
        )
    except Exception as e:
        log.warning(f"Error getting local music {file_path} duration: {e}")
    return duration
def no_padding(info) -> int:
    """Padding callback that always asks for zero bytes of padding.

    ``info`` is accepted because the callback is invoked with one argument;
    it is not used.
    """
    padding_bytes = 0
    return padding_bytes
def remove_id3_tags(input_file: str, config) -> str:
    """
    Create a copy of an MP3 file with its ID3 tags removed.

    The copy is written to ``config.temp_dir`` under the input's base name.
    Only files carrying ID3 v2.3 or v2.4 tags are processed.

    Args:
        input_file: Path of the source MP3 file.
        config: Configuration object (``music_path``, ``temp_dir``).

    Returns:
        Path of the stripped copy relative to ``config.music_path``, or None
        when nothing had to be done.
    """
    audio = MP3(input_file, ID3=ID3)
    strippable_versions = ((2, 3, 0), (2, 4, 0))
    if not audio.tags or audio.tags.version not in strippable_versions:
        return None

    # Destination of the stripped copy.
    stem = os.path.splitext(os.path.basename(input_file))[0]
    out_file_path = os.path.join(config.temp_dir, f"{stem}.mp3")
    relative_path = os.path.relpath(out_file_path, config.music_path)

    # Never overwrite the source with itself.
    if os.path.abspath(input_file) == os.path.abspath(out_file_path):
        log.info(f"File {input_file} = {out_file_path} . Skipping remove_id3_tags.")
        return None

    # Reuse a copy made earlier.
    if os.path.exists(out_file_path):
        log.info(f"File {out_file_path} already exists. Skipping remove_id3_tags.")
        return relative_path

    # Copy first, then strip the tags from the copy only.
    shutil.copy(input_file, out_file_path)
    stripped = MP3(out_file_path, ID3=ID3)
    stripped.delete()
    stripped.save(padding=no_padding)
    log.info(f"File {out_file_path} remove_id3_tags ok.")
    return relative_path
def convert_file_to_mp3(input_file: str, config) -> str:
    """
    Convert an audio file to MP3 with ffmpeg.

    The output goes to ``config.temp_dir`` under the input's base name. The
    input must be located inside ``config.music_path``. That is checked on
    whole path components, so a sibling such as ``/music2`` is rejected for
    the music path ``/music``.

    Args:
        input_file: Path of the source file.
        config: Configuration object (``music_path``, ``temp_dir``,
            ``ffmpeg_location``, ``loudnorm``).

    Returns:
        Path of the MP3 relative to ``config.music_path``, or None when the
        input is rejected, is the output itself, or the conversion failed.
    """
    music_path = config.music_path
    temp_dir = config.temp_dir
    out_file_name = os.path.splitext(os.path.basename(input_file))[0]
    out_file_path = os.path.join(temp_dir, f"{out_file_name}.mp3")
    relative_path = os.path.relpath(out_file_path, music_path)

    # Input and output are the same file: nothing to convert.
    input_absolute_path = os.path.abspath(input_file)
    output_absolute_path = os.path.abspath(out_file_path)
    if input_absolute_path == output_absolute_path:
        log.info(f"File {input_file} = {out_file_path} . Skipping convert_file_to_mp3.")
        return None

    absolute_music_path = os.path.abspath(music_path)
    try:
        inside_music_path = (
            os.path.commonpath([absolute_music_path, input_absolute_path])
            == absolute_music_path
        )
    except ValueError:
        # e.g. different drives: the input cannot be inside the music path.
        inside_music_path = False
    if not inside_music_path:
        log.error(f"Invalid input file path: {input_file}")
        return None

    # Reuse a conversion made earlier.
    if os.path.exists(out_file_path):
        log.info(f"File {out_file_path} already exists. Skipping convert_file_to_mp3.")
        return relative_path

    # Optional audio filter taken from the configuration.
    loudnorm_args = []
    if config.loudnorm:
        loudnorm_args = ["-af", config.loudnorm]

    command = [
        os.path.join(config.ffmpeg_location, "ffmpeg"),
        "-i",
        input_absolute_path,
        "-f",
        "mp3",
        "-vn",
        "-y",
        *loudnorm_args,
        out_file_path,
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        log.exception(f"Error during conversion: {e}")
        return None
    log.info(f"File {input_file} to {out_file_path} convert_file_to_mp3 ok.")
    return relative_path
def _to_utf8(v):
    """Convert a tag value to a plain string.

    Text frames (other than timestamp frames) have their text parts joined;
    when the frame is declared LATIN1 the text is re-decoded as GBK. Lists
    are joined item by item, anything else goes through ``str``.
    """
    is_plain_text_frame = isinstance(v, TextFrame) and not isinstance(
        v, TimeStampTextFrame
    )
    if not is_plain_text_frame:
        if isinstance(v, list):
            return "".join(str(item) for item in v)
        return str(v)

    joined = "".join(v.text)
    if v.encoding != Encoding.LATIN1:
        return joined
    # Round-trip through latin1 bytes and decode them as GBK instead.
    return joined.encode("latin1").decode("GBK", errors="ignore")
def _get_tag_value(tags, k: str) -> str:
"""获取标签值"""
if k not in tags:
return ""
v = tags[k]
return _to_utf8(v)
def _get_alltag_value(tags, k: str) -> str:
"""获取所有标签值"""
v = tags.getall(k)
if len(v) > 0:
return _to_utf8(v[0])
return ""
def _save_picture(picture_data: bytes, save_root: str, file_path: str) -> str:
    """Store cover art for ``file_path`` below ``save_root``.

    The image goes to ``<save_root>/<last 6 hex digits of md5(file_path)>/
    <file base name>.jpg``.

    Returns:
        The target picture path. It is returned even when writing the image
        failed; the failure is only logged.
    """
    # Hash of the audio file path picks the sub-directory.
    digest = hashlib.md5(file_path.encode("utf-8")).hexdigest()
    dir_path = os.path.join(save_root, digest[-6:])
    os.makedirs(dir_path, exist_ok=True)

    stem = os.path.splitext(os.path.basename(file_path))[0]
    picture_path = os.path.join(dir_path, f"{stem}.jpg")
    try:
        _resize_save_image(picture_data, picture_path)
    except Exception as e:
        log.warning(f"Error _resize_save_image: {e}")
    return picture_path
def _resize_save_image(image_bytes: bytes, save_path: str, max_size: int = 300) -> str:
    """Save ``image_bytes`` as a JPEG no larger than ``max_size`` on either side.

    Images that already fit are saved unscaled; larger ones are scaled down
    with the aspect ratio kept.

    Returns:
        ``save_path``, or None when the bytes cannot be opened as an image.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        log.warning(f"Error _resize_save_image: {e}")
        return None

    width, height = image.size
    if width > max_size or height > max_size:
        # One factor for both axes keeps the aspect ratio.
        scale = min(max_size / width, max_size / height)
        image = image.resize(
            (int(width * scale), int(height * scale)), Image.Resampling.LANCZOS
        )
    image.save(save_path, format="JPEG")
    return save_path
def save_picture_by_base64(
    picture_base64_data: str, save_root: str, file_path: str
) -> str:
    """Decode base64 image data and store it through ``_save_picture``.

    Returns:
        The picture path from ``_save_picture``, or None if decoding failed.
    """
    try:
        picture_data = base64.b64decode(picture_base64_data)
    except (TypeError, ValueError) as e:
        log.exception(f"Error decoding base64 data: {e}")
        return None
    else:
        return _save_picture(picture_data, save_root, file_path)
def _decode_ogg_picture(b64_block):
    """Return the image bytes held in one ``metadata_block_picture`` value.

    The value is expected to be a base64-encoded FLAC picture block. If the
    decoded bytes do not parse as such a block they are returned unchanged,
    which covers values that contain just a base64-encoded image. Returns
    None when the value is not valid base64.
    """
    try:
        raw = base64.b64decode(b64_block)
    except (TypeError, ValueError) as e:
        log.warning(f"Error decoding metadata_block_picture: {e}")
        return None
    try:
        return mutagen.flac.Picture(raw).data
    except Exception:
        # Not a picture block; treat the payload as the image itself.
        return raw


def extract_audio_metadata(file_path: str, save_root: str) -> dict:
    """
    Extract the tag metadata of an audio file.

    Cover art, when present, is written below ``save_root`` and its path is
    stored in the ``picture`` field.

    Args:
        file_path: Path of the audio file.
        save_root: Root directory for extracted cover images.

    Returns:
        The ``Metadata`` fields as a dict; all fields are "" when the file
        cannot be opened or carries no tags.
    """
    metadata = Metadata()
    audio = None
    try:
        audio = mutagen.File(file_path)
    except Exception as e:
        log.warning(f"Error extract_audio_metadata file: {file_path} {e}")
    if audio is None:
        return asdict(metadata)
    tags = audio.tags
    if tags is None:
        return asdict(metadata)

    if isinstance(audio, MP3):
        metadata.title = _get_tag_value(tags, "TIT2")
        metadata.artist = _get_tag_value(tags, "TPE1")
        metadata.album = _get_tag_value(tags, "TALB")
        metadata.year = _get_tag_value(tags, "TDRC")
        metadata.genre = _get_tag_value(tags, "TCON")
        metadata.lyrics = _get_alltag_value(tags, "USLT")
        # Use the first attached picture frame as the cover.
        for tag in tags.values():
            if isinstance(tag, APIC):
                metadata.picture = _save_picture(tag.data, save_root, file_path)
                break
    elif isinstance(audio, FLAC):
        metadata.title = _get_tag_value(tags, "TITLE")
        metadata.artist = _get_tag_value(tags, "ARTIST")
        metadata.album = _get_tag_value(tags, "ALBUM")
        metadata.year = _get_tag_value(tags, "DATE")
        metadata.genre = _get_tag_value(tags, "GENRE")
        if audio.pictures:
            metadata.picture = _save_picture(
                audio.pictures[0].data, save_root, file_path
            )
        if "lyrics" in audio:
            metadata.lyrics = audio["lyrics"][0]
    elif isinstance(audio, MP4):
        metadata.title = _get_tag_value(tags, "\xa9nam")
        metadata.artist = _get_tag_value(tags, "\xa9ART")
        metadata.album = _get_tag_value(tags, "\xa9alb")
        metadata.year = _get_tag_value(tags, "\xa9day")
        metadata.genre = _get_tag_value(tags, "\xa9gen")
        if "covr" in tags:
            metadata.picture = _save_picture(tags["covr"][0], save_root, file_path)
    elif isinstance(audio, OggVorbis):
        metadata.title = _get_tag_value(tags, "TITLE")
        metadata.artist = _get_tag_value(tags, "ARTIST")
        metadata.album = _get_tag_value(tags, "ALBUM")
        metadata.year = _get_tag_value(tags, "DATE")
        metadata.genre = _get_tag_value(tags, "GENRE")
        if "metadata_block_picture" in tags:
            # The value is a base64 picture block, not JSON.
            picture_data = _decode_ogg_picture(tags["metadata_block_picture"][0])
            if picture_data:
                metadata.picture = _save_picture(picture_data, save_root, file_path)
    elif isinstance(audio, ASF):
        metadata.title = _get_tag_value(tags, "Title")
        metadata.artist = _get_tag_value(tags, "Author")
        metadata.album = _get_tag_value(tags, "WM/AlbumTitle")
        metadata.year = _get_tag_value(tags, "WM/Year")
        metadata.genre = _get_tag_value(tags, "WM/Genre")
        if "WM/Picture" in tags:
            metadata.picture = _save_picture(
                tags["WM/Picture"][0].value, save_root, file_path
            )
    elif isinstance(audio, WavPack):
        metadata.title = _get_tag_value(tags, "Title")
        metadata.artist = _get_tag_value(tags, "Artist")
        metadata.album = _get_tag_value(tags, "Album")
        metadata.year = _get_tag_value(tags, "Year")
        metadata.genre = _get_tag_value(tags, "Genre")
        # Guarded lookup: a missing ``pictures`` attribute must not abort
        # extraction. NOTE(review): confirm whether WavPack objects expose it.
        pictures = getattr(audio, "pictures", None)
        if pictures:
            metadata.picture = _save_picture(pictures[0].data, save_root, file_path)
    elif isinstance(audio, WAVE):
        # NOTE(review): confirm these key names match the tag container
        # mutagen uses for WAVE files; absent keys simply yield "".
        metadata.title = _get_tag_value(tags, "Title")
        metadata.artist = _get_tag_value(tags, "Artist")
    return asdict(metadata)
def _set_mp3_tags(audio, info: Metadata) -> None:
    """Replace the ID3 tags of an MP3 file with the values in ``info`` and save.

    Existing tags are discarded: a fresh ``ID3`` container is installed first.
    """
    audio.tags = ID3()
    text_frames = (
        ("TIT2", TIT2, info.title),
        ("TPE1", TPE1, info.artist),
        ("TALB", TALB, info.album),
        ("TDRC", TDRC, info.year),
        ("TCON", TCON, info.genre),
    )
    for frame_id, frame_cls, text in text_frames:
        audio[frame_id] = frame_cls(encoding=3, text=text)
    # Lyrics are stored in a USLT frame.
    if info.lyrics:
        audio["USLT"] = USLT(encoding=3, lang="eng", text=info.lyrics)
    # Cover image, read from the path held in info.picture.
    if info.picture:
        with open(info.picture, "rb") as img_file:
            cover_bytes = img_file.read()
        audio["APIC"] = APIC(
            encoding=3, mime="image/jpeg", type=3, desc="Cover", data=cover_bytes
        )
    audio.save()  # persist the changes
def _set_flac_tags(audio, info: Metadata) -> None:
    """Write the values in ``info`` to a FLAC file's tags (not saved here).

    The cover image is attached as a picture block; raw image bytes cannot
    be passed to ``add_picture`` directly.
    """
    audio["TITLE"] = info.title
    audio["ARTIST"] = info.artist
    audio["ALBUM"] = info.album
    audio["DATE"] = info.year
    audio["GENRE"] = info.genre
    if info.lyrics:
        audio["LYRICS"] = info.lyrics
    if info.picture:
        with open(info.picture, "rb") as img_file:
            image_data = img_file.read()
        picture = mutagen.flac.Picture()
        picture.type = 3  # same picture type the MP3 writer uses
        picture.mime = "image/jpeg"  # same MIME type the MP3 writer declares
        picture.data = image_data
        # Replace earlier covers instead of piling up one per call.
        audio.clear_pictures()
        audio.add_picture(picture)
def _set_mp4_tags(audio, info: Metadata) -> None:
    """Write the values in ``info`` to an MP4 file's tags (not saved here).

    Text values are written under the same keys ``extract_audio_metadata``
    reads (``\\xa9nam``, ``\\xa9ART``, ...). If the tag object rejects such a
    key with ``KeyError``, the matching friendly key name is used instead.
    That is how an "easy" tag wrapper is expected to behave.
    """
    text_values = (
        ("\xa9nam", "title", info.title),
        ("\xa9ART", "artist", info.artist),
        ("\xa9alb", "album", info.album),
        ("\xa9day", "date", info.year),
        ("\xa9gen", "genre", info.genre),
    )
    for atom_key, easy_key, value in text_values:
        try:
            audio[atom_key] = value
        except KeyError:
            audio[easy_key] = value
    if info.picture:
        with open(info.picture, "rb") as img_file:
            image_data = img_file.read()
        try:
            audio["covr"] = [image_data]
        except KeyError as e:
            # The tag object has no cover key; keep the text tags.
            log.warning(f"Cover art not written for MP4 tags: {e}")
def _set_ogg_tags(audio, info: Metadata) -> None:
    """Write the values in ``info`` to an Ogg Vorbis file's tags (not saved here).

    The cover is stored in ``metadata_block_picture`` as a base64-encoded
    FLAC picture block rather than as a bare base64 image.
    """
    audio["TITLE"] = info.title
    audio["ARTIST"] = info.artist
    audio["ALBUM"] = info.album
    audio["DATE"] = info.year
    audio["GENRE"] = info.genre
    if info.lyrics:
        audio["LYRICS"] = info.lyrics
    if info.picture:
        with open(info.picture, "rb") as img_file:
            image_data = img_file.read()
        picture = mutagen.flac.Picture()
        picture.type = 3  # same picture type the MP3 writer uses
        picture.mime = "image/jpeg"  # same MIME type the MP3 writer declares
        picture.data = image_data
        encoded_block = base64.b64encode(picture.write()).decode("ascii")
        audio["metadata_block_picture"] = [encoded_block]
def _set_asf_tags(audio, info: Metadata) -> None:
    """Write the values in ``info`` to an ASF file's tags (not saved here)."""
    text_values = (
        ("Title", info.title),
        ("Author", info.artist),
        ("WM/AlbumTitle", info.album),
        ("WM/Year", info.year),
        ("WM/Genre", info.genre),
    )
    for key, value in text_values:
        audio[key] = value
    if info.picture:
        with open(info.picture, "rb") as img_file:
            cover_bytes = img_file.read()
        audio["WM/Picture"] = cover_bytes
def _set_wave_tags(audio, info: Metadata) -> None:
    """Write title and artist from ``info`` to a WAVE file's tags (not saved here)."""
    for key, value in (("Title", info.title), ("Artist", info.artist)):
        audio[key] = value
def set_music_tag_to_file(file_path: str, info: Metadata) -> str:
    """
    Write tag information to a music file.

    Args:
        file_path: Path of the music file.
        info: Metadata to write.

    Returns:
        "OK" on success, otherwise a short error message.
    """
    audio = mutagen.File(file_path, easy=True)
    if audio is None:
        log.error(f"Unable to open file {file_path}")
        return "Unable to open file"

    # Checked in order; the first matching type wins.
    tag_writers = (
        (MP3, _set_mp3_tags),
        (FLAC, _set_flac_tags),
        (MP4, _set_mp4_tags),
        (OggVorbis, _set_ogg_tags),
        (ASF, _set_asf_tags),
        (WAVE, _set_wave_tags),
    )
    for audio_type, write_tags in tag_writers:
        if isinstance(audio, audio_type):
            write_tags(audio, info)
            break
    else:
        log.error(f"Unsupported file type for {file_path}")
        return "Unsupported file type"

    try:
        audio.save()
        log.info(f"Tags saved successfully to {file_path}")
        return "OK"
    except Exception as e:
        log.exception(f"Error saving tags: {e}")
        return "Error saving tags"

View File

@@ -0,0 +1,448 @@
#!/usr/bin/env python3
"""网络请求和下载相关工具函数"""
import asyncio
import hashlib
import logging
import os
import time
from collections import OrderedDict
from pathlib import Path
from time import sleep
from urllib.parse import parse_qs, urlparse
import aiohttp
import edge_tts
log = logging.getLogger(__package__)
async def downloadfile(url: str) -> str:
    """
    Download a URL and return the response body as text.

    Args:
        url: File URL; only ``http`` and ``https`` are accepted.

    Returns:
        The response text.

    Raises:
        Warning: If the URL scheme is neither HTTP nor HTTPS.
    """
    parsed_url = urlparse(url)
    scheme = parsed_url.scheme
    # Basic validation: HTTP and HTTPS only.
    if scheme != "http" and scheme != "https":
        raise Warning(
            f"Invalid URL scheme: {parsed_url.scheme}. Only HTTP and HTTPS are allowed."
        )
    target_url = parsed_url.geturl()
    async with aiohttp.ClientSession() as session:
        # The timeout keeps the request from hanging for long.
        async with session.get(target_url, timeout=5) as response:
            # Raise for error statuses.
            response.raise_for_status()
            body = await response.text()
    return body
def _parse_bili_list_url(url: str) -> tuple:
    """Extract ``(list id, list kind)`` from a space.bilibili.com list URL.

    The kind is one of ``"season"``, ``"collect"`` or ``"create"``.

    Raises:
        ValueError: If the URL is not a supported collection or favourites list.
    """
    parsed_url = urlparse(url)
    path = parsed_url.path
    query_params = parse_qs(parsed_url.query)
    lid = None
    list_kind = None
    if parsed_url.hostname == "space.bilibili.com":
        if "/favlist" in path:
            lid = query_params.get("fid", [None])[0]
            ctype = query_params.get("ctype", [None])[0]
            if ctype == "11":
                list_kind = "create"
            elif ctype == "21":
                list_kind = "collect"
            else:
                raise ValueError("当前只支持合集和收藏夹")
        elif "/lists/" in path:
            parts = path.split("/")
            if len(parts) >= 4 and "?" in url:
                lid = parts[3]
                list_kind = query_params.get("type", [None])[0]
    if list_kind not in ("season", "collect", "create"):
        raise ValueError("当前只支持合集和收藏夹")
    return lid, list_kind


async def check_bili_fav_list(url: str) -> dict:
    """
    List the videos of a Bilibili favourites list or collection.

    The URL is validated before any request is made, so unsupported URLs
    raise ``ValueError`` instead of an unbound-variable error. Pauses between
    pages use ``asyncio.sleep`` so the event loop is not blocked.

    Args:
        url: URL of the favourites list or collection.

    Returns:
        ``{key: title}``, where the key is the bvid for collections
        (``season``/``collect``) and the video URL for favourites (``create``).

    Raises:
        ValueError: If the URL is not a supported list type.
        Exception: If a page request does not return HTTP 200.
    """
    lid, list_kind = _parse_bili_list_url(url)
    is_favourites = list_kind == "create"
    items_key = "medias" if is_favourites else "archives"

    bvid_info = {}
    page_size = 100
    page_num = 1
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0 Safari/537.36",
        "Accept": "application/json, text/plain, */*",
        "Referer": url,
        "Origin": "https://space.bilibili.com",
    }
    async with aiohttp.ClientSession(headers=headers) as session:
        while True:
            if is_favourites:
                list_url = f"https://api.bilibili.com/x/v3/fav/resource/list?media_id={lid}&pn={page_num}&ps={page_size}&order=mtime"
            else:
                list_url = f"https://api.bilibili.com/x/polymer/web-space/seasons_archives_list?season_id={lid}&page_size={page_size}&page_num={page_num}"
            async with session.get(list_url) as response:
                if response.status != 200:
                    raise Exception(f"Failed to fetch data from {list_url}")
                data = await response.json()
            # "data" may be present but null; treat that as an empty page.
            items = (data.get("data") or {}).get(items_key) or []
            if not items:
                break
            for item in items:
                bvid = item.get("bvid", None)
                title = item.get("title", None)
                if is_favourites:
                    bvid_info[f"https://www.bilibili.com/video/{bvid}"] = title
                else:
                    bvid_info[bvid] = title
            # A short page is the last page.
            if len(items) < page_size:
                break
            page_num += 1
            if not is_favourites:
                # Pause between collection pages without blocking the loop.
                await asyncio.sleep(1)
    return bvid_info
async def download_playlist(config, url: str, dirname: str):
    """
    Start a yt-dlp process that downloads a playlist as MP3 files.

    Args:
        config: Configuration object (``download_path``, ``ffmpeg_location``,
            ``proxy``, ``enable_yt_dlp_cookies``, ``yt_dlp_cookies_path``,
            ``loudnorm``).
        url: Playlist URL.
        dirname: Sub-directory name used in the output template.

    Returns:
        The started subprocess object; it is not awaited here.
    """
    output_template = f"{dirname}/%(title)s.%(ext)s"
    args = [
        "yt-dlp",
        "--yes-playlist",
        "-x",
        "--audio-format",
        "mp3",
        "--audio-quality",
        "0",
        "--paths",
        config.download_path,
        "-o",
        output_template,
        "--ffmpeg-location",
        f"{config.ffmpeg_location}",
    ]
    # Optional switches, appended in a fixed order.
    if config.proxy:
        args.extend(["--proxy", f"{config.proxy}"])
    if config.enable_yt_dlp_cookies:
        args.extend(["--cookies", f"{config.yt_dlp_cookies_path}"])
    if config.loudnorm:
        args.extend(["--postprocessor-args", f"-af {config.loudnorm}"])
    args.append(url)
    sbp_args = tuple(args)
    cmd = " ".join(sbp_args)
    log.info(f"download_playlist: {cmd}")
    return await asyncio.create_subprocess_exec(*sbp_args)
async def download_one_music(config, url: str, name: str = ""):
    """
    Start a yt-dlp process that downloads a single track as MP3.

    Args:
        config: Configuration object (``download_path``, ``ffmpeg_location``,
            ``proxy``, ``enable_yt_dlp_cookies``, ``yt_dlp_cookies_path``,
            ``loudnorm``).
        url: Track URL.
        name: Optional file name (without extension); the track title is
            used when empty.

    Returns:
        The started subprocess object; it is not awaited here.
    """
    output_template = f"{name}.%(ext)s" if name else "%(title)s.%(ext)s"
    args = [
        "yt-dlp",
        "--no-playlist",
        "-x",
        "--audio-format",
        "mp3",
        "--audio-quality",
        "0",
        "--paths",
        config.download_path,
        "-o",
        output_template,
        "--ffmpeg-location",
        f"{config.ffmpeg_location}",
    ]
    # Optional switches, appended in a fixed order.
    if config.proxy:
        args.extend(["--proxy", f"{config.proxy}"])
    if config.enable_yt_dlp_cookies:
        args.extend(["--cookies", f"{config.yt_dlp_cookies_path}"])
    if config.loudnorm:
        args.extend(["--postprocessor-args", f"-af {config.loudnorm}"])
    args.append(url)
    sbp_args = tuple(args)
    cmd = " ".join(sbp_args)
    log.info(f"download_one_music: {cmd}")
    return await asyncio.create_subprocess_exec(*sbp_args)
async def fetch_json_get(url: str, headers: dict, config) -> dict:
    """
    Send a GET request and return the JSON body as a dict.

    Args:
        url: Request URL.
        headers: Request headers.
        config: Configuration object; its ``proxy`` is used when set.

    Returns:
        The decoded JSON dict, or ``{}`` for a non-200 status, a non-dict
        body, or any error (which is logged).
    """
    connector = None
    proxy = None
    if config and config.proxy:
        connector = aiohttp.TCPConnector(
            ssl=False,  # set to True to verify certificates (proxy must support it)
            limit=10,
        )
        proxy = config.proxy
    try:
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.get(
                url,
                headers=headers,
                proxy=proxy,
                timeout=10,  # seconds
            ) as response:
                if response.status != 200:
                    log.error(f"HTTP Error: {response.status} {url}")
                    return {}
                data = await response.json()
                log.info(f"fetch_json_get: {url} success {data}")
                # Callers rely on getting a dict back.
                if not isinstance(data, dict):
                    log.warning(f"Expected dict, but got {type(data)}: {data}")
                    return {}
                return data
    except aiohttp.ClientError as e:
        log.error(f"ClientError fetching {url} (proxy: {proxy}): {e}")
        return {}
    except asyncio.TimeoutError:
        log.error(f"Timeout fetching {url} (proxy: {proxy})")
        return {}
    except Exception as e:
        log.error(f"Unexpected error fetching {url} (proxy: {proxy}): {e}")
        return {}
    finally:
        # Close the connector if it is still open.
        if connector and not connector.closed:
            await connector.close()
class LRUCache(OrderedDict):
"""LRU 缓存实现"""
def __init__(self, max_size: int = 1000):
super().__init__()
self.max_size = max_size
def __setitem__(self, key, value):
if key in self:
# 移动到末尾(最近使用)
self.move_to_end(key)
super().__setitem__(key, value)
# 如果超出大小限制,删除最早使用的项
if len(self) > self.max_size:
self.popitem(last=False)
def __getitem__(self, key):
# 访问时移动到末尾(最近使用)
if key in self:
self.move_to_end(key)
return super().__getitem__(key)
class MusicUrlCache:
"""音乐 URL 缓存管理器"""
def __init__(self, default_expire_days: int = 1, max_size: int = 1000):
self.cache = LRUCache(max_size)
self.default_expire_days = default_expire_days
self.log = logging.getLogger(__name__)
async def get(self, url: str, headers: dict = None, config=None) -> str:
"""
获取URL(优先从缓存获取,没有则请求API)
Args:
url: 原始URL
headers: API请求需要的headers
config: 配置对象
Returns:
str: 真实播放URL
"""
# 先查询缓存
cached_url = self._get_from_cache(url)
if cached_url:
self.log.info(f"Using cached url: {cached_url}")
return cached_url
# 缓存未命中,请求API
return await self._fetch_from_api(url, headers, config)
def _get_from_cache(self, url: str) -> str:
"""从缓存中获取URL"""
try:
cached_url, expire_time = self.cache[url]
if time.time() > expire_time:
# 缓存过期,删除
del self.cache[url]
return ""
return cached_url
except KeyError:
return ""
async def _fetch_from_api(self, url: str, headers: dict = None, config=None) -> str:
"""从API获取真实URL"""
data = await fetch_json_get(url, headers or {}, config)
if not isinstance(data, dict):
self.log.error(f"Invalid API response format: {data}")
return ""
real_url = data.get("url")
if not real_url:
self.log.error(f"No url in API response: {data}")
return ""
# 获取过期时间
expire_time = self._parse_expire_time(data)
# 缓存结果
self._set_cache(url, real_url, expire_time)
self.log.info(
f"Cached url, expire_time: {expire_time}, cache size: {len(self.cache)}"
)
return real_url
def _parse_expire_time(self, data: dict) -> float | None:
"""解析API返回的过期时间"""
try:
extra = data.get("extra", {})
expire_info = extra.get("expire", {})
if expire_info and expire_info.get("canExpire"):
expire_time = expire_info.get("time")
if expire_time:
return float(expire_time)
except Exception as e:
self.log.warning(f"Failed to parse expire time: {e}")
return None
def _set_cache(self, url: str, real_url: str, expire_time: float = None):
"""设置缓存"""
if expire_time is None:
expire_time = time.time() + (self.default_expire_days * 24 * 3600)
self.cache[url] = (real_url, expire_time)
def clear(self):
"""清空缓存"""
self.cache.clear()
@property
def size(self) -> int:
"""当前缓存大小"""
return len(self.cache)
async def text_to_mp3(
    text: str, save_dir: str, voice: str = "zh-CN-XiaoxiaoNeural"
) -> str:
    """Convert text to an MP3 speech file with edge-tts.

    The output file name is the MD5 of ``"{text}_{voice}"``, so the same text
    with the same voice is synthesized only once and later calls reuse the
    existing file.

    Args:
        text: Text to synthesize.
        save_dir: Directory in which the MP3 file is stored (created if
            missing).
        voice: edge-tts voice name.

    Returns:
        Full path of the MP3 file.

    Raises:
        RuntimeError: If synthesis fails; the original error is chained.
    """
    # Make sure the target directory exists.
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Key the file on both text and voice so different voices do not collide.
    content = f"{text}_{voice}".encode()
    file_hash = hashlib.md5(content).hexdigest()
    mp3_path = os.path.join(save_dir, f"{file_hash}.mp3")

    # Already synthesized: reuse it.
    if os.path.exists(mp3_path):
        return mp3_path

    # Synthesize into a temporary sibling file and rename it into place only
    # on success. Writing straight to mp3_path would leave a truncated file
    # behind on failure, and the existence check above would then serve that
    # broken file forever.
    tmp_path = None
    try:
        communicate = edge_tts.Communicate(text, voice)
        # id() is unique among live objects, so concurrent calls in this
        # process never share a temporary file.
        tmp_path = f"{mp3_path}.{os.getpid()}.{id(communicate)}.tmp"
        await communicate.save(tmp_path)
        os.replace(tmp_path, mp3_path)
        log.info(f"语音文件生成成功: {mp3_path}")
    except Exception as e:
        raise RuntimeError(f"生成语音文件失败: {e}") from e
    finally:
        # Runs on failure and on cancellation; after a successful rename the
        # temporary file no longer exists and nothing is removed.
        if tmp_path is not None and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
    return mp3_path

View File

@@ -0,0 +1,299 @@
#!/usr/bin/env python3
"""系统操作和环境相关工具函数"""
import asyncio
import copy
import hashlib
import logging
import os
import platform
import random
import string
import urllib.parse
from http.cookies import SimpleCookie
from urllib.parse import urlparse
import aiohttp
from requests.utils import cookiejar_from_dict
log = logging.getLogger(__package__)
def parse_cookie_string(cookie_string: str):
    """Parse a cookie string into a cookie jar.

    Args:
        cookie_string: Cookie string to parse.

    Returns:
        The object built by ``cookiejar_from_dict`` from the parsed
        name/value pairs.
    """
    parsed = SimpleCookie()
    parsed.load(cookie_string)
    name_to_value = {}
    for name, morsel in parsed.items():
        name_to_value[name] = morsel.value
    return cookiejar_from_dict(name_to_value, cookiejar=None, overwrite=True)
def validate_proxy(proxy_str: str) -> bool:
    """Validate the format of a proxy URL.

    Args:
        proxy_str: Proxy string, e.g. ``http://host:port``.

    Returns:
        True when the format is valid.

    Raises:
        ValueError: If the scheme is not http/https, or the hostname or port
            is missing.
    """
    parts = urlparse(proxy_str)
    allowed_schemes = ("http", "https")
    if parts.scheme not in allowed_schemes:
        raise ValueError("Proxy scheme must be http or https")
    # Hostname is checked first so the port is only read when a host exists.
    if not parts.hostname or not parts.port:
        raise ValueError("Proxy hostname and port must be set")
    return True
def get_random(length: int) -> str:
    """Generate a random string of ASCII letters and digits.

    Characters are drawn independently (with replacement), so any
    non-negative ``length`` is supported. The previous ``random.sample``
    implementation drew without replacement, which raised ``ValueError`` for
    lengths above 62 and never repeated a character.

    Note: this uses the ``random`` module and is not suitable for
    security-sensitive tokens.

    Args:
        length: Number of characters to generate.

    Returns:
        The random string.
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choices(alphabet, k=length))
def deepcopy_data_no_sensitive_info(data, fields_to_anonymize: list = None):
    """Deep-copy ``data`` and mask its sensitive top-level fields.

    Only top-level dict keys or object attributes are masked; nested
    structures are copied unchanged.

    Args:
        data: Dict or object to copy.
        fields_to_anonymize: Names of the fields to mask. Defaults to the
            account, password and HTTP-auth credential fields.

    Returns:
        The deep copy with each present sensitive field set to ``"******"``.
    """
    sensitive = fields_to_anonymize
    if sensitive is None:
        sensitive = [
            "account",
            "password",
            "httpauth_username",
            "httpauth_password",
        ]

    clone = copy.deepcopy(data)
    is_mapping = isinstance(clone, dict)
    for name in sensitive:
        if is_mapping:
            # Dicts: mask by key.
            if name in clone:
                clone[name] = "******"
        elif hasattr(clone, name):
            # Other objects: mask by attribute.
            setattr(clone, name, "******")
    return clone
def try_add_access_control_param(config, url: str) -> str:
    """Append the access-control ``code`` query parameter to ``url``.

    The code is the SHA-256 hex digest of the unquoted URL path concatenated
    with the configured HTTP-auth username and password.

    Args:
        config: Configuration object; reads ``disable_httpauth``,
            ``httpauth_username`` and ``httpauth_password``.
        url: Original URL.

    Returns:
        ``url`` unchanged when HTTP auth is disabled, otherwise the URL with
        the ``code`` parameter set.
    """
    if config.disable_httpauth:
        return url

    parsed = urllib.parse.urlparse(url)
    file_path = urllib.parse.unquote(parsed.path)
    secret = file_path + config.httpauth_username + config.httpauth_password
    correct_code = hashlib.sha256(secret.encode("utf-8")).hexdigest()
    log.debug(f"rewrite url: [{file_path}, {correct_code}]")

    # Rebuild the URL with the code merged into the existing query string.
    query_args = dict(urllib.parse.parse_qsl(parsed.query))
    query_args["code"] = correct_code
    new_query = urllib.parse.urlencode(query_args, doseq=True)
    return parsed._replace(query=new_query).geturl()
def is_docker() -> bool:
    """Return True when the marker file ``/app/.dockerenv`` exists."""
    marker_file = "/app/.dockerenv"
    return os.path.exists(marker_file)
def get_os_architecture() -> str:
    """Map the machine type to a release architecture name.

    Returns:
        ``"amd64"``, ``"arm64"`` or ``"arm-v7"``; for any other machine type
        the string ``"unknown architecture: <machine>"``.
    """
    machine = platform.machine().lower()
    exact_names = {
        "x86_64": "amd64",
        "amd64": "amd64",
        "aarch64": "arm64",
        "arm64": "arm64",
    }
    if machine in exact_names:
        return exact_names[machine]
    # Every remaining name containing "arm" (including armv7*) maps to arm-v7.
    if "arm" in machine:
        return "arm-v7"
    return f"unknown architecture: {machine}"
async def get_latest_version(package_name: str) -> str | None:
    """Fetch the latest version of a package from PyPI.

    Args:
        package_name: Name of the package.

    Returns:
        The latest version string, or None on failure (non-200 response,
        connection error, timeout, or a response without ``info.version``).
    """
    url = f"https://pypi.org/pypi/{package_name}/json"
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                if response.status != 200:
                    return None
                data = await response.json()
                return data["info"]["version"]
    except (aiohttp.ClientError, asyncio.TimeoutError, KeyError, TypeError) as e:
        # Honour the documented "None on failure" contract instead of letting
        # network or payload errors escape to the caller.
        log.warning(f"get_latest_version failed: {e}")
        return None
async def restart_xiaomusic() -> int:
    """Restart the xiaomusic program via ``supervisorctl restart xiaomusic``.

    Returns:
        Exit code of the ``supervisorctl`` process.
    """
    argv = ("supervisorctl", "restart", "xiaomusic")
    cmd = " ".join(argv)
    log.info(f"restart_xiaomusic: {cmd}")

    # NOTE(review): the 2-second delay presumably lets the caller finish
    # responding before the restart — confirm.
    await asyncio.sleep(2)

    process = await asyncio.create_subprocess_exec(*argv)
    # Wait for the child process to finish.
    exit_code = await process.wait()
    log.info(f"restart_xiaomusic completed with exit code {exit_code}")
    return exit_code
async def update_version(version: str, lite: bool = True) -> str:
    """Download and unpack a xiaomusic release into ``/app``.

    Args:
        version: Release tag to download.
        lite: Whether to fetch the ``-lite`` build.

    Returns:
        A result message: the value returned by ``download_and_extract``, or
        an explanation when not running in Docker or when the architecture
        is unknown.
    """
    if not is_docker():
        ret = "xiaomusic 更新只能在 docker 中进行"
        log.info(ret)
        return ret

    arch = get_os_architecture()
    if "unknown" in arch:
        log.warning(f"update_version failed: {arch}")
        return arch

    lite_tag = "-lite" if lite else ""
    # Upstream asset layout, e.g.
    # https://github.com/hanxi/xiaomusic/releases/download/main/app-amd64-lite.tar.gz
    url = f"https://gproxy.hanxi.cc/proxy/hanxi/xiaomusic/releases/download/{version}/app-{arch}{lite_tag}.tar.gz"
    return await download_and_extract(url, "/app")
async def download_and_extract(url: str, target_directory: str) -> str:
    """Download a file into ``target_directory`` and unpack it if it is a tar.gz.

    The local file name is the last ``/``-separated segment of ``url``.

    Args:
        url: Download URL.
        target_directory: Directory to download into (created if missing).

    Returns:
        ``"OK"`` on success, otherwise a message describing the failure.
    """
    ret = "OK"
    # Create the target directory.
    os.makedirs(target_directory, exist_ok=True)

    # Download asynchronously with aiohttp.
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status != 200:
                ret = f"下载失败, 状态码: {response.status}"
                log.warning(ret)
                return ret

            file_name = os.path.join(target_directory, url.split("/")[-1])
            file_name = os.path.normpath(file_name)
            # Compare absolute paths component-wise. A plain string-prefix
            # test against the raw target_directory rejects relative
            # directories such as "./out" (normpath drops the "./") and
            # accepts sibling directories that merely share a prefix.
            if not _is_inside_directory(file_name, target_directory):
                log.warning(f"Invalid file path: {file_name}")
                return "Invalid file path"

            with open(file_name, "wb") as f:
                # Write chunk by chunk to keep memory usage low.
                async for chunk in response.content.iter_any():
                    f.write(chunk)
            log.info(f"文件下载完成: {file_name}")

            # Unpack the downloaded archive.
            if file_name.endswith(".tar.gz"):
                await extract_tar_gz(file_name, target_directory)
            else:
                ret = f"下载失败, 包有问题: {file_name}"
                log.warning(ret)
    return ret


def _is_inside_directory(path: str, directory: str) -> bool:
    """Return True when ``path`` lies strictly inside ``directory``."""
    root = os.path.abspath(directory)
    candidate = os.path.abspath(path)
    if candidate == root:
        # An empty or "." last URL segment resolves to the directory itself.
        return False
    try:
        return os.path.commonpath([root, candidate]) == root
    except ValueError:
        # Raised e.g. for paths on different drives on Windows.
        return False
async def extract_tar_gz(file_name: str, target_directory: str) -> None:
    """Start ``tar -xzvf`` to unpack an archive into ``target_directory``.

    The tar process is only started; this function does not wait for it to
    finish.

    Args:
        file_name: Path of the tar.gz archive.
        target_directory: Directory to extract into.
    """
    tar_argv = ("tar", "-xzvf", file_name, "-C", target_directory)
    # Spawn the extraction process without awaiting its completion.
    await asyncio.create_subprocess_exec(*tar_argv)
    log.info(f"extract_tar_gz ing {file_name}")

View File

@@ -0,0 +1,264 @@
#!/usr/bin/env python3
"""文本处理和搜索相关工具函数"""
import difflib
import re
from collections.abc import AsyncIterator
from opencc import OpenCC
# Traditional-to-Simplified Chinese converter (OpenCC "t2s" configuration)
cc = OpenCC("t2s")
# TTS-related patterns.
# Characters stripped before estimating speech length: CJK and ASCII quotes
# and brackets, plus a hyphen that is not part of a "--" run.
_no_elapse_chars = re.compile(r"([「」『』《》“”'\"()()]|(?<!-)-(?!-))", re.UNICODE)
# Sentence-ending punctuation, full-width and ASCII. None of these may be an
# empty string: str.endswith("") is always True.
_ending_punctuations = ("。", "?", "!", ";", ".", "?", "!", ";")
# Chinese numeral characters mapped to their values: digits 0-9 followed by
# the units ten, hundred, thousand, ten-thousand and hundred-million.
chinese_to_arabic = {
    "零": 0,
    "一": 1,
    "二": 2,
    "三": 3,
    "四": 4,
    "五": 5,
    "六": 6,
    "七": 7,
    "八": 8,
    "九": 9,
    "十": 10,
    "百": 100,
    "千": 1000,
    "万": 10000,
    "亿": 100000000,
}
def calculate_tts_elapse(text: str) -> float:
    """Estimate how long TTS playback of ``text`` takes.

    The estimate is the number of characters left after removing those
    matched by ``_no_elapse_chars``, divided by a fixed speed constant.
    """
    # A single fixed speed keeps this simple; the value was picked by trial
    # and error.
    speed = 4.5
    # Quotes and brackets do not add to the spoken duration, so drop them.
    spoken_text = _no_elapse_chars.sub("", text)
    return len(spoken_text) / speed
async def split_sentences(text_stream: AsyncIterator[str]) -> AsyncIterator[str]:
    """Regroup a stream of text chunks into sentence-sized pieces.

    Chunks are accumulated until the buffer ends with one of
    ``_ending_punctuations``; the buffer is then yielded and reset. Any
    non-empty remainder is yielded when the stream ends.
    """
    pending = ""
    async for piece in text_stream:
        pending = pending + piece
        if not pending.endswith(_ending_punctuations):
            continue
        yield pending
        pending = ""
    if pending:
        yield pending
def find_key_by_partial_string(dictionary: dict[str, str], partial_key: str) -> str:
    """Return the value of the first key that occurs inside ``partial_key``.

    Keys are tested in dictionary order with a substring check; None is
    returned when no key is contained in ``partial_key``.
    """
    hits = (value for key, value in dictionary.items() if key in partial_key)
    return next(hits, None)
def traditional_to_simple(to_convert: str) -> str:
    """Convert Traditional Chinese text to Simplified Chinese."""
    simplified = cc.convert(to_convert)
    return simplified
def keyword_detection(user_input: str, str_list: list, n: int) -> tuple[list, list]:
    """Split candidates into those containing the keyword and the rest.

    Args:
        user_input: Keyword typed by the user.
        str_list: Candidate strings.
        n: Maximum number of matches to return; -1 returns all of them.

    Returns:
        ``(matches, remains)``. Matches are ordered from most to least
        similar to ``user_input``; matches beyond the first ``n`` are moved
        to the front of ``remains``.
    """

    def similarity(candidate):
        return difflib.SequenceMatcher(None, candidate, user_input).ratio()

    # Partition on substring containment, keeping the input order.
    containing = [item for item in str_list if user_input in item]
    remains = [item for item in str_list if user_input not in item]

    # Descending order: the most similar candidates come first.
    ranked = sorted(containing, key=similarity, reverse=True)

    # n == -1, or more requested than available: return every match.
    if n == -1 or n > len(ranked):
        return ranked, remains

    # Keep the top n; the overflow goes back in front of the non-matches.
    return ranked[:n], ranked[n:] + remains
def real_search(prompt: str, candidates: list, cutoff: float, n: int) -> list:
    """Search by keyword first, then pad with fuzzy matches.

    When keyword matching yields fewer than ``n`` results, close matches
    from the remaining candidates (as rated by
    ``difflib.get_close_matches``) are appended.
    """
    exact_hits, leftovers = keyword_detection(prompt, candidates, n=n)
    if len(exact_hits) >= n:
        return exact_hits
    # Not enough keyword matches: fall back to fuzzy matching.
    exact_hits += difflib.get_close_matches(prompt, leftovers, n=n, cutoff=cutoff)
    return exact_hits
def find_best_match(
    user_input: str,
    collection: list,
    cutoff: float = 0.6,
    n: int = 1,
    extra_search_index: dict = None,
) -> list:
    """Find the best matches for ``user_input`` in ``collection``.

    Both the query and the candidates are lower-cased and converted to
    Simplified Chinese before comparison; results are returned in their
    original form.

    Args:
        user_input: Text typed by the user.
        collection: Candidate items.
        cutoff: Similarity threshold for fuzzy matching.
        n: Number of results to return.
        extra_search_index: Optional mapping of additional search keys to
            result values, consulted only when ``collection`` yields fewer
            than ``n`` matches.

    Returns:
        A list of at most ``n`` matched items.
    """
    normalized_to_original = {}
    for item in collection:
        normalized_to_original[traditional_to_simple(item.lower())] = item

    query = traditional_to_simple(user_input.lower())
    hits = real_search(query, list(normalized_to_original), cutoff, n)
    found = [normalized_to_original[hit] for hit in hits]
    if extra_search_index is None or len(hits) >= n:
        return found[:n]

    # Not enough results yet: search the extra index, skipping entries whose
    # value has already been found.
    extra_lookup = {
        traditional_to_simple(key.lower()): target
        for key, target in extra_search_index.items()
        if target not in found
    }
    extra_hits = real_search(query, list(extra_lookup), cutoff, n)
    found.extend(extra_lookup[hit] for hit in extra_hits)
    return found[:n]
def fuzzyfinder(
    user_input: str, collection: list, extra_search_index: dict = None
) -> list:
    """Fuzzy search: up to 10 matches using a similarity cutoff of 0.1."""
    options = {
        "cutoff": 0.1,
        "n": 10,
        "extra_search_index": extra_search_index,
    }
    return find_best_match(user_input, collection, **options)
def custom_sort_key(s: str) -> tuple:
    """Sort key for song names.

    Names with a numeric prefix come first (ordered by that number, then by
    the full string), then names with a numeric suffix (ordered by the text
    before the number, then by the number), then everything else in
    lexicographic order.
    """
    leading = re.match(r"^(\d+)", s)
    if leading:
        # Numeric prefix: order by the number, then by the whole string.
        return (0, int(leading.group(0)), s)

    trailing = re.search(r"(\d+)$", s)
    if trailing:
        # Numeric suffix: order by the leading text, then by the number.
        return (1, s[: trailing.start()], int(trailing.group(0)))

    # Neither: plain lexicographic order.
    return (2, s)
def chinese_to_number(chinese: str) -> int:
    """Convert a Chinese numeral string to an integer.

    Characters that are not in ``chinese_to_arabic`` are ignored.

    Args:
        chinese: Chinese numeral string, e.g. "一百二十三".

    Returns:
        The corresponding integer (0 when no numeral character is present).
    """
    # Special case: a leading "十" is shorthand for "一十".
    if chinese.startswith("十"):
        chinese = "一" + chinese

    # A single unit character stands for its own value. Use .get so that a
    # single unknown character is ignored (like anywhere else in the string)
    # instead of raising KeyError.
    if len(chinese) == 1 and chinese_to_arabic.get(chinese, 0) >= 10:
        return chinese_to_arabic[chinese]

    unit = 1
    total = 0
    # Scan right to left: units set the multiplier for the digits on their
    # left.
    for char in reversed(chinese):
        val = chinese_to_arabic.get(char)
        if val is None:
            continue
        if val >= 10:
            if val > unit:
                unit = val
            else:
                # A smaller unit to the left of a larger one compounds with
                # it.
                unit *= val
        else:
            total += val * unit
    return total
def parse_str_to_dict(s: str, d1: str = ",", d2: str = ":") -> dict:
    """Parse a ``k1:v1,k2:v2`` style string into a dict.

    Args:
        s: String to parse.
        d1: Separator between pairs (default comma).
        d2: Separator between key and value (default colon).

    Returns:
        The parsed dict. Segments that do not split into exactly two parts
        are skipped.
    """
    split_segments = (segment.split(d2) for segment in s.split(d1))
    # Keep only well-formed key/value pairs.
    return {pair[0]: pair[1] for pair in split_segments if len(pair) == 2}
def list2str(li: list, verbose: bool = False) -> str:
    """Format a list for display, abbreviating long lists.

    Args:
        li: The list to format.
        verbose: When True, always show the full list.

    Returns:
        The full list representation, or for lists longer than 5 items (and
        ``verbose`` False) the first two and last two items plus the length.
    """
    if len(li) <= 5 or verbose:
        return f"{li}"
    head, tail = li[:2], li[-2:]
    return f"{head} ... {tail} with len: {len(li)}"