mirror of
https://github.com/hanxi/xiaomusic.git
synced 2026-05-24 11:35:46 +08:00
refactor: 重构拆分 utils 文件
This commit is contained in:
1502
xiaomusic/utils.py
1502
xiaomusic/utils.py
File diff suppressed because it is too large
Load Diff
125
xiaomusic/utils/__init__.py
Normal file
125
xiaomusic/utils/__init__.py
Normal file
@@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Utils package - 工具函数模块
|
||||
|
||||
将原 utils.py 拆分为多个职责清晰的子模块:
|
||||
- text_utils: 文本处理和搜索
|
||||
- file_utils: 文件和目录操作
|
||||
- music_utils: 音乐文件处理
|
||||
- network_utils: 网络请求和下载
|
||||
- system_utils: 系统操作和环境
|
||||
"""
|
||||
|
||||
# 从各子模块导入常用函数,保持向后兼容
|
||||
from xiaomusic.utils.file_utils import (
|
||||
chmoddir,
|
||||
chmodfile,
|
||||
not_in_dirs,
|
||||
remove_common_prefix,
|
||||
safe_join_path,
|
||||
traverse_music_directory,
|
||||
)
|
||||
from xiaomusic.utils.music_utils import (
|
||||
Metadata,
|
||||
convert_file_to_mp3,
|
||||
extract_audio_metadata,
|
||||
get_duration_by_ffprobe,
|
||||
get_duration_by_mutagen,
|
||||
get_local_music_duration,
|
||||
get_web_music_duration,
|
||||
is_m4a,
|
||||
is_mp3,
|
||||
remove_id3_tags,
|
||||
save_picture_by_base64,
|
||||
set_music_tag_to_file,
|
||||
)
|
||||
from xiaomusic.utils.network_utils import (
|
||||
MusicUrlCache,
|
||||
check_bili_fav_list,
|
||||
download_one_music,
|
||||
download_playlist,
|
||||
downloadfile,
|
||||
fetch_json_get,
|
||||
text_to_mp3,
|
||||
)
|
||||
from xiaomusic.utils.system_utils import (
|
||||
deepcopy_data_no_sensitive_info,
|
||||
download_and_extract,
|
||||
get_latest_version,
|
||||
get_os_architecture,
|
||||
get_random,
|
||||
is_docker,
|
||||
parse_cookie_string,
|
||||
restart_xiaomusic,
|
||||
try_add_access_control_param,
|
||||
update_version,
|
||||
validate_proxy,
|
||||
)
|
||||
from xiaomusic.utils.text_utils import (
|
||||
calculate_tts_elapse,
|
||||
chinese_to_number,
|
||||
custom_sort_key,
|
||||
find_best_match,
|
||||
find_key_by_partial_string,
|
||||
fuzzyfinder,
|
||||
keyword_detection,
|
||||
list2str,
|
||||
parse_str_to_dict,
|
||||
split_sentences,
|
||||
traditional_to_simple,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# text_utils
|
||||
"calculate_tts_elapse",
|
||||
"chinese_to_number",
|
||||
"custom_sort_key",
|
||||
"find_best_match",
|
||||
"find_key_by_partial_string",
|
||||
"fuzzyfinder",
|
||||
"keyword_detection",
|
||||
"list2str",
|
||||
"parse_str_to_dict",
|
||||
"split_sentences",
|
||||
"traditional_to_simple",
|
||||
# file_utils
|
||||
"chmoddir",
|
||||
"chmodfile",
|
||||
"not_in_dirs",
|
||||
"remove_common_prefix",
|
||||
"safe_join_path",
|
||||
"traverse_music_directory",
|
||||
# music_utils
|
||||
"Metadata",
|
||||
"convert_file_to_mp3",
|
||||
"extract_audio_metadata",
|
||||
"get_duration_by_ffprobe",
|
||||
"get_duration_by_mutagen",
|
||||
"get_local_music_duration",
|
||||
"get_web_music_duration",
|
||||
"is_m4a",
|
||||
"is_mp3",
|
||||
"remove_id3_tags",
|
||||
"set_music_tag_to_file",
|
||||
"save_picture_by_base64",
|
||||
# network_utils
|
||||
"MusicUrlCache",
|
||||
"check_bili_fav_list",
|
||||
"download_one_music",
|
||||
"download_playlist",
|
||||
"downloadfile",
|
||||
"fetch_json_get",
|
||||
"text_to_mp3",
|
||||
# system_utils
|
||||
"deepcopy_data_no_sensitive_info",
|
||||
"download_and_extract",
|
||||
"get_latest_version",
|
||||
"get_os_architecture",
|
||||
"get_random",
|
||||
"is_docker",
|
||||
"parse_cookie_string",
|
||||
"restart_xiaomusic",
|
||||
"try_add_access_control_param",
|
||||
"update_version",
|
||||
"validate_proxy",
|
||||
]
|
||||
189
xiaomusic/utils/file_utils.py
Normal file
189
xiaomusic/utils/file_utils.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python3
|
||||
"""文件和目录操作相关工具函数"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
log = logging.getLogger(__package__)
|
||||
|
||||
|
||||
def _get_depth_path(root: str, directory: str, depth: int) -> str:
    """Truncate ``root`` to at most ``depth`` path components below ``directory``.

    Args:
        root: Path whose leading ``len(directory)`` characters are skipped.
        directory: Base directory the truncated path is rebuilt on.
        depth: Number of components (relative to ``directory``) to keep.

    Returns:
        ``directory`` joined with the first ``depth`` components of the
        remainder, or ``root`` unchanged when the remainder has fewer than
        ``depth`` components.
    """
    segments = root[len(directory) :].strip(os.sep).split(os.sep)
    if len(segments) < depth:
        return root
    return os.path.join(directory, *segments[:depth])
|
||||
|
||||
|
||||
def _append_files_result(
    result: dict, root: str, joinpath: str, files: list, support_extension: set
) -> None:
    """Append the supported files of one directory to ``result``.

    The entry is keyed by the base name of ``root`` and is created (as an
    empty list) even when no file qualifies.

    Args:
        result: Mapping of directory name to file paths; mutated in place.
        root: Directory whose base name is used as the key.
        joinpath: Directory prefix joined onto every accepted file name.
        files: File names to filter.
        support_extension: Accepted extensions, compared in lower case.
    """
    bucket = result.setdefault(os.path.basename(root), [])
    bucket.extend(
        os.path.join(joinpath, entry)
        for entry in files
        # Skip hidden files, then keep only supported extensions.
        if not entry.startswith(".")
        and os.path.splitext(entry)[1].lower() in support_extension
    )
|
||||
|
||||
|
||||
def traverse_music_directory(
    directory: str, depth: int, exclude_dirs: set, support_extension: set
) -> dict:
    """Walk a music directory and group supported files by directory name.

    Symbolic links are followed. Directories nested deeper than ``depth``
    are folded into an ancestor computed by ``_get_depth_path``.

    Args:
        directory: Root directory to walk.
        depth: Maximum nesting level that still gets its own group.
        exclude_dirs: Directory names that are not descended into.
        support_extension: Accepted file extensions.

    Returns:
        Mapping of directory name to a list of file paths.
    """
    collected = {}
    prefix_len = len(directory)
    for root, dirs, files in os.walk(directory, followlinks=True):
        # Prune in place so os.walk does not descend into excluded directories.
        dirs[:] = [name for name in dirs if name not in exclude_dirs]

        level = root[prefix_len:].count(os.sep) + 1
        if level > depth:
            group_root = _get_depth_path(root, directory, depth - 1)
        else:
            group_root = root
        _append_files_result(collected, group_root, root, files, support_extension)
    return collected
|
||||
|
||||
|
||||
def safe_join_path(safe_root: str, directory: str) -> str:
    """Join ``directory`` onto ``safe_root`` and refuse results outside the root.

    Containment is checked on whole path components. A plain string-prefix
    test would accept a sibling such as ``/data/music_private`` for the root
    ``/data/music``.

    Args:
        safe_root: Directory the result must stay inside.
        directory: Path to join onto ``safe_root``.

    Returns:
        The normalized joined path.

    Raises:
        ValueError: If the normalized path is neither the root itself nor
            located below it.
    """
    joined = os.path.join(safe_root, directory)
    normalized_directory = os.path.normpath(joined)
    normalized_root = os.path.normpath(safe_root)

    # Appending the separator forces the match to end on a component boundary;
    # rstrip keeps a filesystem-root safe_root ("/") from becoming "//".
    root_with_sep = normalized_root.rstrip(os.sep) + os.sep
    inside_root = (
        normalized_directory == normalized_root
        or normalized_directory.startswith(root_with_sep)
    )
    if not inside_root:
        raise ValueError(f"Access to directory '{joined}' is not allowed.")
    return normalized_directory
|
||||
|
||||
|
||||
def _longest_common_prefix(file_names: list) -> str:
    """Return the longest character-wise prefix shared by all names.

    An empty list yields an empty string.
    """
    if not file_names:
        return ""
    # os.path.commonprefix compares character by character, not by path
    # component, which is the comparison wanted here.
    return os.path.commonprefix(file_names)
|
||||
|
||||
|
||||
def remove_common_prefix(directory: str) -> None:
    """Rename the entries of ``directory`` by dropping their shared prefix.

    After the prefix is removed, a remainder shaped like
    ``[p|P]<number> <digits><name>.<suffix>`` is rewritten to
    ``<number>.<name>.<suffix>``, with dots inside ``<name>`` replaced by
    spaces. An entry whose whole name equals the prefix is left alone.

    Args:
        directory: Directory whose entries are renamed in place.
    """
    entries = os.listdir(directory)

    shared = _longest_common_prefix(entries)

    log.info(f'Common prefix identified: "{shared}"')

    pattern = re.compile(r"^[pP]?(\d+)\s+\d*(.+?)\.(.*$)")
    for filename in entries:
        if filename == shared:
            continue
        if not filename.startswith(shared):
            continue

        new_filename = filename[len(shared) :]
        found = pattern.search(new_filename.strip())
        if found:
            num, raw_name, suffix = found.groups()
            name = raw_name.replace(".", " ").strip()
            new_filename = f"{num}.{name}.{suffix}"

        os.rename(
            os.path.join(directory, filename),
            os.path.join(directory, new_filename),
        )
        log.debug(f'Renamed: "{filename}" to "{new_filename}"')
|
||||
|
||||
|
||||
def not_in_dirs(filename: str, ignore_absolute_dirs: list) -> bool:
    """Tell whether ``filename`` lies outside every ignored directory.

    The comparison is made on whole path components, so ``/music/tmp2`` is
    not treated as being inside ``/music/tmp``. A trailing separator on an
    ignored directory is tolerated.

    Args:
        filename: File path; it is made absolute before comparison.
        ignore_absolute_dirs: Absolute directory paths to exclude.

    Returns:
        ``False`` if the file's directory is an ignored directory or below
        one, otherwise ``True``.
    """
    file_dir = os.path.dirname(os.path.abspath(filename))
    for ignore_dir in ignore_absolute_dirs:
        ignore_root = ignore_dir.rstrip(os.sep)
        # Appending the separator makes the prefix end on a component boundary.
        if file_dir == ignore_root or file_dir.startswith(ignore_root + os.sep):
            log.info(f"{file_dir} in {ignore_dir}")
            return False  # file is inside an excluded directory

    return True  # file is outside all excluded directories
|
||||
|
||||
|
||||
def chmodfile(file_path: str) -> None:
    """Set the permission bits of ``file_path`` to 775.

    Best effort: any failure is logged and not raised.
    """
    mode = 0o775
    try:
        os.chmod(file_path, mode)
    except Exception as err:
        log.info(f"chmodfile failed: {err}")
|
||||
|
||||
|
||||
def chmoddir(dir_path: str) -> None:
    """Set the permission bits of every regular file directly in ``dir_path`` to 775.

    Subdirectories are skipped and not descended into. Failures on
    individual files are logged and do not stop the loop.
    """
    for entry in os.listdir(dir_path):
        item_path = os.path.join(dir_path, entry)
        # Only plain files are touched.
        if not os.path.isfile(item_path):
            continue
        try:
            os.chmod(item_path, 0o775)
            log.info(f"Changed permissions of file: {item_path}")
        except Exception as err:
            log.info(f"chmoddir failed: {err}")
|
||||
695
xiaomusic/utils/music_utils.py
Normal file
695
xiaomusic/utils/music_utils.py
Normal file
@@ -0,0 +1,695 @@
|
||||
#!/usr/bin/env python3
|
||||
"""音乐文件处理相关工具函数"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from dataclasses import asdict, dataclass
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import mutagen
|
||||
from mutagen.asf import ASF
|
||||
from mutagen.flac import FLAC
|
||||
from mutagen.id3 import (
|
||||
APIC,
|
||||
ID3,
|
||||
TALB,
|
||||
TCON,
|
||||
TDRC,
|
||||
TIT2,
|
||||
TPE1,
|
||||
USLT,
|
||||
Encoding,
|
||||
TextFrame,
|
||||
TimeStampTextFrame,
|
||||
)
|
||||
from mutagen.mp3 import MP3
|
||||
from mutagen.mp4 import MP4
|
||||
from mutagen.oggvorbis import OggVorbis
|
||||
from mutagen.wave import WAVE
|
||||
from mutagen.wavpack import WavPack
|
||||
from PIL import Image
|
||||
|
||||
from xiaomusic.const import SUPPORT_MUSIC_TYPE
|
||||
|
||||
log = logging.getLogger(__package__)
|
||||
|
||||
|
||||
@dataclass
class Metadata:
    """Tag information of a music file.

    Every field defaults to an empty string. The hand-written ``__init__``
    accepts an optional mapping and copies the known keys from it; keys
    missing from the mapping become empty strings.
    """

    title: str = ""
    artist: str = ""
    album: str = ""
    year: str = ""
    genre: str = ""
    picture: str = ""
    lyrics: str = ""

    def __init__(self, info=None):
        # With no (or an empty) mapping the class-level defaults stay in effect.
        if not info:
            return
        for key in ("title", "artist", "album", "year", "genre", "picture", "lyrics"):
            setattr(self, key, info.get(key, ""))
|
||||
|
||||
|
||||
def is_mp3(url: str) -> bool:
    """Return True when the MIME type guessed for ``url`` is ``audio/mpeg``."""
    guessed_type, _encoding = mimetypes.guess_type(url)
    return guessed_type == "audio/mpeg"
|
||||
|
||||
|
||||
def is_m4a(url: str) -> bool:
    """Return True when ``url`` ends with the ``.m4a`` extension.

    The comparison ignores case, so ``SONG.M4A`` is recognized as well.
    """
    return url.lower().endswith(".m4a")
|
||||
|
||||
|
||||
async def _get_web_music_duration(
    session, url: str, config, start: int = 0, end: int = 500
) -> float:
    """Estimate the duration of a remote audio file from a byte range.

    Requests bytes ``start``-``end`` of ``url``, writes them to a temporary
    file and passes that file to ``get_local_music_duration``. The temporary
    file is always removed afterwards.

    Args:
        session: Object providing an async ``get(url, headers=...)`` context
            manager (the caller in this module passes an aiohttp session).
        url: Address of the audio file.
        config: Configuration object forwarded to ``get_local_music_duration``.
        start: First byte of the requested range.
        end: Last byte of the requested range.

    Returns:
        Duration in seconds, or 0 if measuring the partial file failed.
    """
    range_headers = {"Range": f"bytes={start}-{end}"}

    async with session.get(url, headers=range_headers) as response:
        payload = await response.read()

    # delete=False: the file must outlive the ``with`` so it can be reopened by path.
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(payload)
        tmp_path = tmp.name

    duration = 0
    try:
        duration = await get_local_music_duration(tmp_path, config)
    except Exception as e:
        log.error(f"Error _get_web_music_duration: {e}")
    finally:
        os.unlink(tmp_path)

    return duration
|
||||
|
||||
|
||||
async def get_web_music_duration(url: str, config) -> tuple[float, str]:
    """Determine the duration of a remote audio file.

    If the URL path does not end in a supported music extension, the URL is
    requested once with redirects enabled and replaced by the final response
    URL. The duration is then probed from the first 500 bytes and, if that
    yields nothing, from the first 3000 bytes, using a session with a
    3-second total timeout.

    Args:
        url: Address of the audio file.
        config: Configuration object forwarded to the probing helper.

    Returns:
        ``(duration_in_seconds, final_url)``. The duration is 0 on failure;
        errors are logged, not raised.
    """
    duration = 0
    try:
        parsed_url = urlparse(url)
        extension = os.path.splitext(parsed_url.path)[1]
        if extension.lower() not in SUPPORT_MUSIC_TYPE:
            browser_headers = {
                "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36"
            }
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    parsed_url.geturl(),
                    allow_redirects=True,
                    headers=browser_headers,
                ) as response:
                    url = str(response.url)

        probe_timeout = aiohttp.ClientTimeout(total=3)
        async with aiohttp.ClientSession(timeout=probe_timeout) as session:
            # Try a small range first; only fetch more if that gave no duration.
            for range_end in (500, 3000):
                duration = await _get_web_music_duration(
                    session, url, config, start=0, end=range_end
                )
                if duration > 0:
                    break
    except Exception as e:
        log.error(f"Error get_web_music_duration: {e}")
    return duration, url
|
||||
|
||||
|
||||
async def get_local_music_duration(filename: str, config) -> float:
    """Return the playback duration of a local audio file.

    The method named by ``config.get_duration_type`` is tried first
    (``"ffprobe"`` selects ffprobe, anything else mutagen). If it reports 0,
    the other method is tried once.

    Args:
        filename: Path of the audio file.
        config: Object providing ``get_duration_type`` and ``ffmpeg_location``.

    Returns:
        Duration in seconds; 0 if both methods report 0.
    """
    use_ffprobe = config.get_duration_type == "ffprobe"

    if use_ffprobe:
        duration = get_duration_by_ffprobe(filename, config.ffmpeg_location)
    else:
        duration = await get_duration_by_mutagen(filename)

    if duration != 0:
        return duration

    # Retry once with the other method.
    if use_ffprobe:
        return await get_duration_by_mutagen(filename)
    return get_duration_by_ffprobe(filename, config.ffmpeg_location)
|
||||
|
||||
|
||||
async def get_duration_by_mutagen(file_path: str) -> float:
    """Read the duration of an audio file with mutagen.

    The file is parsed in the event loop's default executor. MP3 files (as
    judged by ``is_mp3``) are opened with ``mutagen.mp3.MP3``, everything
    else with ``mutagen.File``.

    Returns:
        ``info.length`` of the parsed file, or 0 if anything fails (the
        error is logged as a warning).
    """
    try:
        loop = asyncio.get_event_loop()
        if is_mp3(file_path):
            opener = mutagen.mp3.MP3
        else:
            opener = mutagen.File
        parsed = await loop.run_in_executor(None, opener, file_path)
        return parsed.info.length
    except Exception as e:
        log.warning(f"Error getting local music {file_path} duration: {e}")
    return 0
|
||||
|
||||
|
||||
def get_duration_by_ffprobe(file_path: str, ffmpeg_location: str) -> float:
    """Read the duration of an audio file with ffprobe.

    ffprobe is asked for ``format=duration`` as JSON. stderr is captured
    separately from stdout so that diagnostics printed by ffprobe cannot
    corrupt the JSON document that is parsed.

    Args:
        file_path: Path of the audio file.
        ffmpeg_location: Directory that contains the ``ffprobe`` executable.

    Returns:
        Duration in seconds, or 0 if ffprobe cannot be run or its output
        cannot be parsed (the error is logged as a warning).
    """
    duration = 0
    try:
        cmd_args = [
            os.path.join(ffmpeg_location, "ffprobe"),
            "-v",
            "error",  # only error-level messages
            "-show_entries",
            "format=duration",  # only the duration entry
            "-of",
            "json",  # JSON output
            file_path,
        ]

        full_command = " ".join(cmd_args)
        log.info(f"待执行的完整命令 ffprobe command: {full_command}")

        result = subprocess.run(
            cmd_args,
            stdout=subprocess.PIPE,
            # Keep diagnostics out of stdout; json.loads() needs pure JSON.
            stderr=subprocess.PIPE,
            text=True,
        )

        log.info(
            f"命令执行结果 command result - return code: {result.returncode}, stdout: {result.stdout}"
        )
        if result.stderr:
            log.warning(f"ffprobe stderr for {file_path}: {result.stderr}")

        ffprobe_output = json.loads(result.stdout)

        duration = float(ffprobe_output["format"]["duration"])
        log.info(
            f"Successfully extracted duration: {duration} seconds for file: {file_path}"
        )

    except Exception as e:
        log.warning(f"Error getting local music {file_path} duration: {e}")
    return duration
|
||||
|
||||
|
||||
def no_padding(info) -> int:
    """Padding callback that always requests zero bytes of padding.

    ``info`` is ignored. In this file the function is passed as the
    ``padding`` argument when saving a stripped MP3.
    """
    return 0
|
||||
|
||||
|
||||
def remove_id3_tags(input_file: str, config) -> str:
    """Create a copy of an MP3 file without its ID3 tags.

    Only files carrying ID3 v2.3 or v2.4 tags are processed. The copy is
    written to ``config.temp_dir`` under the input's base name with an
    ``.mp3`` extension; an already existing copy is reused.

    Args:
        input_file: Path of the MP3 file.
        config: Object providing ``music_path`` and ``temp_dir``.

    Returns:
        Path of the stripped copy relative to ``config.music_path``, or
        ``None`` when the file has no v2.3/v2.4 tags or when input and
        output resolve to the same path.
    """
    source_tags = MP3(input_file, ID3=ID3).tags

    # Nothing to do unless ID3 v2.3 or v2.4 tags are present.
    if not source_tags or source_tags.version not in ((2, 3, 0), (2, 4, 0)):
        return None

    stem = os.path.splitext(os.path.basename(input_file))[0]
    out_file_path = os.path.join(config.temp_dir, f"{stem}.mp3")
    relative_path = os.path.relpath(out_file_path, config.music_path)

    # Input and output are the same file.
    if os.path.abspath(input_file) == os.path.abspath(out_file_path):
        log.info(f"File {input_file} = {out_file_path} . Skipping remove_id3_tags.")
        return None

    # Reuse a previously produced copy.
    if os.path.exists(out_file_path):
        log.info(f"File {out_file_path} already exists. Skipping remove_id3_tags.")
        return relative_path

    # Work on a copy so the original file keeps its tags.
    shutil.copy(input_file, out_file_path)
    stripped = MP3(out_file_path, ID3=ID3)
    stripped.delete()
    stripped.save(padding=no_padding)
    log.info(f"File {out_file_path} remove_id3_tags ok.")
    return relative_path
|
||||
|
||||
|
||||
def convert_file_to_mp3(input_file: str, config) -> str:
    """Convert an audio file to MP3 with ffmpeg.

    The result is written to ``config.temp_dir`` under the input's base name
    with an ``.mp3`` extension; an already existing result is reused. The
    input must be located inside ``config.music_path``. This is checked on
    whole path components, so a sibling directory such as ``/music2`` is not
    accepted for the root ``/music``.

    Args:
        input_file: Path of the file to convert.
        config: Object providing ``music_path``, ``temp_dir``,
            ``ffmpeg_location`` and ``loudnorm``.

    Returns:
        Path of the MP3 relative to ``config.music_path``, or ``None`` when
        input and output are the same file, the input is outside the music
        directory, or ffmpeg exits with an error.
    """
    music_path = config.music_path
    temp_dir = config.temp_dir

    out_file_name = os.path.splitext(os.path.basename(input_file))[0]
    out_file_path = os.path.join(temp_dir, f"{out_file_name}.mp3")
    relative_path = os.path.relpath(out_file_path, music_path)

    # Input and output are the same file.
    input_absolute_path = os.path.abspath(input_file)
    output_absolute_path = os.path.abspath(out_file_path)
    if input_absolute_path == output_absolute_path:
        log.info(f"File {input_file} = {out_file_path} . Skipping convert_file_to_mp3.")
        return None

    absolute_music_path = os.path.abspath(music_path)
    # The trailing separator makes the prefix end on a component boundary.
    music_root = absolute_music_path.rstrip(os.sep) + os.sep
    inside_music_path = (
        input_absolute_path == absolute_music_path
        or input_absolute_path.startswith(music_root)
    )
    if not inside_music_path:
        log.error(f"Invalid input file path: {input_file}")
        return None

    # Reuse a previously produced file.
    if os.path.exists(out_file_path):
        log.info(f"File {out_file_path} already exists. Skipping convert_file_to_mp3.")
        return relative_path

    # Optional audio filter taken from the configuration.
    loudnorm_args = []
    if config.loudnorm:
        loudnorm_args = ["-af", config.loudnorm]

    command = [
        os.path.join(config.ffmpeg_location, "ffmpeg"),
        "-i",
        input_absolute_path,
        "-f",
        "mp3",
        "-vn",
        "-y",
        *loudnorm_args,
        out_file_path,
    ]

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        log.exception(f"Error during conversion: {e}")
        return None

    log.info(f"File {input_file} to {out_file_path} convert_file_to_mp3 ok.")
    return relative_path
|
||||
|
||||
|
||||
def _to_utf8(v):
    """Convert a tag value to a string.

    - Text frames other than timestamp frames: their text parts are joined.
      If the frame is marked LATIN1, the joined text is re-encoded to latin1
      bytes and decoded as GBK, ignoring undecodable bytes.
    - Lists: the string forms of the items are concatenated.
    - Anything else: ``str(v)``.
    """
    is_plain_text_frame = isinstance(v, TextFrame) and not isinstance(
        v, TimeStampTextFrame
    )
    if is_plain_text_frame:
        joined = "".join(v.text)
        if v.encoding != Encoding.LATIN1:
            return joined
        return joined.encode("latin1").decode("GBK", errors="ignore")
    if isinstance(v, list):
        return "".join(map(str, v))
    return str(v)
|
||||
|
||||
|
||||
def _get_tag_value(tags, k: str) -> str:
    """Return the tag stored under ``k`` as a string, or "" if it is absent."""
    return _to_utf8(tags[k]) if k in tags else ""
|
||||
|
||||
|
||||
def _get_alltag_value(tags, k: str) -> str:
    """Return the first value from ``tags.getall(k)`` as a string, or "" if there is none."""
    matches = tags.getall(k)
    if not matches:
        return ""
    return _to_utf8(matches[0])
|
||||
|
||||
|
||||
def _save_picture(picture_data: bytes, save_root: str, file_path: str) -> str:
    """Store cover art for ``file_path`` below ``save_root`` and return its path.

    The target directory is named after the last six hex digits of the MD5
    of ``file_path``; the file is named after the audio file's base name
    with a ``.jpg`` extension.

    Returns:
        The target path. It is returned even when writing the image raised
        (the error is only logged).
    """
    digest = hashlib.md5(file_path.encode("utf-8")).hexdigest()
    target_dir = os.path.join(save_root, digest[-6:])
    os.makedirs(target_dir, exist_ok=True)

    stem = os.path.splitext(os.path.basename(file_path))[0]
    picture_path = os.path.join(target_dir, f"{stem}.jpg")

    try:
        _resize_save_image(picture_data, picture_path)
    except Exception as e:
        log.warning(f"Error _resize_save_image: {e}")
    return picture_path
|
||||
|
||||
|
||||
def _resize_save_image(image_bytes: bytes, save_path: str, max_size: int = 300) -> str:
    """Save image bytes as a JPEG no larger than ``max_size`` on either side.

    The image is converted to RGB. If either side exceeds ``max_size`` it is
    scaled down proportionally with LANCZOS resampling; otherwise it is
    saved at its original size.

    Returns:
        ``save_path``, or ``None`` if the bytes could not be opened or
        converted (a warning is logged).
    """
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        log.warning(f"Error _resize_save_image: {e}")
        return None

    width, height = image.size
    if width > max_size or height > max_size:
        # One common factor for both sides keeps the aspect ratio.
        ratio = min(max_size / width, max_size / height)
        image = image.resize(
            (int(width * ratio), int(height * ratio)), Image.Resampling.LANCZOS
        )

    image.save(save_path, format="JPEG")
    return save_path
|
||||
|
||||
|
||||
def save_picture_by_base64(
    picture_base64_data: str, save_root: str, file_path: str
) -> str:
    """Decode base64 image data and store it via ``_save_picture``.

    Returns:
        The path returned by ``_save_picture``, or ``None`` if the base64
        data cannot be decoded (the error is logged).
    """
    try:
        raw_picture = base64.b64decode(picture_base64_data)
    except (TypeError, ValueError) as e:
        log.exception(f"Error decoding base64 data: {e}")
        return None
    else:
        return _save_picture(raw_picture, save_root, file_path)
|
||||
|
||||
|
||||
def extract_audio_metadata(file_path: str, save_root: str) -> dict:
    """Extract tag metadata, and cover art if present, from an audio file.

    Supported containers: MP3, FLAC, MP4, Ogg Vorbis, ASF, WavPack and WAVE.
    Embedded cover art is written through ``_save_picture`` and its path is
    stored in the ``picture`` field.

    Args:
        file_path: Path of the audio file.
        save_root: Root directory for extracted cover images.

    Returns:
        The ``Metadata`` fields as a dict. All fields are empty strings if
        the file cannot be opened or carries no tags.
    """
    # Local import: Picture is only needed for the Ogg cover below.
    from mutagen.flac import Picture

    metadata = Metadata()

    audio = None
    try:
        audio = mutagen.File(file_path)
    except Exception as e:
        log.warning(f"Error extract_audio_metadata file: {file_path} {e}")

    if audio is None:
        return asdict(metadata)

    tags = audio.tags
    if tags is None:
        return asdict(metadata)

    if isinstance(audio, MP3):
        metadata.title = _get_tag_value(tags, "TIT2")
        metadata.artist = _get_tag_value(tags, "TPE1")
        metadata.album = _get_tag_value(tags, "TALB")
        metadata.year = _get_tag_value(tags, "TDRC")
        metadata.genre = _get_tag_value(tags, "TCON")
        metadata.lyrics = _get_alltag_value(tags, "USLT")
        for tag in tags.values():
            if isinstance(tag, APIC):
                metadata.picture = _save_picture(tag.data, save_root, file_path)
                break

    elif isinstance(audio, FLAC):
        metadata.title = _get_tag_value(tags, "TITLE")
        metadata.artist = _get_tag_value(tags, "ARTIST")
        metadata.album = _get_tag_value(tags, "ALBUM")
        metadata.year = _get_tag_value(tags, "DATE")
        metadata.genre = _get_tag_value(tags, "GENRE")
        if audio.pictures:
            metadata.picture = _save_picture(
                audio.pictures[0].data, save_root, file_path
            )
        if "lyrics" in audio:
            metadata.lyrics = audio["lyrics"][0]

    elif isinstance(audio, MP4):
        metadata.title = _get_tag_value(tags, "\xa9nam")
        metadata.artist = _get_tag_value(tags, "\xa9ART")
        metadata.album = _get_tag_value(tags, "\xa9alb")
        metadata.year = _get_tag_value(tags, "\xa9day")
        metadata.genre = _get_tag_value(tags, "\xa9gen")
        if "covr" in tags:
            metadata.picture = _save_picture(tags["covr"][0], save_root, file_path)

    elif isinstance(audio, OggVorbis):
        metadata.title = _get_tag_value(tags, "TITLE")
        metadata.artist = _get_tag_value(tags, "ARTIST")
        metadata.album = _get_tag_value(tags, "ALBUM")
        metadata.year = _get_tag_value(tags, "DATE")
        metadata.genre = _get_tag_value(tags, "GENRE")
        if "metadata_block_picture" in tags:
            # The comment holds a base64-encoded FLAC picture block, not JSON.
            try:
                picture = Picture(base64.b64decode(tags["metadata_block_picture"][0]))
            except Exception as e:
                # A malformed cover must not discard the text tags read above.
                log.warning(f"Error decoding cover of {file_path}: {e}")
            else:
                metadata.picture = _save_picture(picture.data, save_root, file_path)

    elif isinstance(audio, ASF):
        metadata.title = _get_tag_value(tags, "Title")
        metadata.artist = _get_tag_value(tags, "Author")
        metadata.album = _get_tag_value(tags, "WM/AlbumTitle")
        metadata.year = _get_tag_value(tags, "WM/Year")
        metadata.genre = _get_tag_value(tags, "WM/Genre")
        if "WM/Picture" in tags:
            metadata.picture = _save_picture(
                tags["WM/Picture"][0].value, save_root, file_path
            )

    elif isinstance(audio, WavPack):
        metadata.title = _get_tag_value(tags, "Title")
        metadata.artist = _get_tag_value(tags, "Artist")
        metadata.album = _get_tag_value(tags, "Album")
        metadata.year = _get_tag_value(tags, "Year")
        metadata.genre = _get_tag_value(tags, "Genre")
        # NOTE(review): WavPack files appear not to expose ``pictures``;
        # getattr avoids an AttributeError in that case - confirm.
        pictures = getattr(audio, "pictures", None)
        if pictures:
            metadata.picture = _save_picture(pictures[0].data, save_root, file_path)

    elif isinstance(audio, WAVE):
        metadata.title = _get_tag_value(tags, "Title")
        metadata.artist = _get_tag_value(tags, "Artist")

    return asdict(metadata)
|
||||
|
||||
|
||||
def _set_mp3_tags(audio, info: Metadata) -> None:
    """Replace the tags of an MP3 file with the values from ``info`` and save.

    Existing tags are discarded: a fresh ``ID3`` container is assigned
    first. Lyrics (USLT) and a JPEG front cover (APIC) are added only when
    ``info`` provides them.
    """
    audio.tags = ID3()

    text_frames = (
        ("TIT2", TIT2, info.title),
        ("TPE1", TPE1, info.artist),
        ("TALB", TALB, info.album),
        ("TDRC", TDRC, info.year),
        ("TCON", TCON, info.genre),
    )
    for frame_id, frame_cls, value in text_frames:
        audio[frame_id] = frame_cls(encoding=3, text=value)

    # Lyrics are stored in a USLT frame.
    if info.lyrics:
        audio["USLT"] = USLT(encoding=3, lang="eng", text=info.lyrics)

    # info.picture is a path; the file content becomes the cover frame.
    if info.picture:
        with open(info.picture, "rb") as cover_file:
            cover_bytes = cover_file.read()
        audio["APIC"] = APIC(
            encoding=3, mime="image/jpeg", type=3, desc="Cover", data=cover_bytes
        )
    audio.save()
|
||||
|
||||
|
||||
def _set_flac_tags(audio, info: Metadata) -> None:
    """Write the values from ``info`` into a FLAC file's tags.

    The cover image is wrapped in a ``mutagen.flac.Picture`` block, since
    ``add_picture`` expects a picture object rather than raw image bytes.
    The file is not saved here.
    """
    audio["TITLE"] = info.title
    audio["ARTIST"] = info.artist
    audio["ALBUM"] = info.album
    audio["DATE"] = info.year
    audio["GENRE"] = info.genre
    if info.lyrics:
        audio["LYRICS"] = info.lyrics
    if info.picture:
        # Local import keeps this block independent of the module import list.
        from mutagen.flac import Picture

        with open(info.picture, "rb") as img_file:
            image_data = img_file.read()
        cover = Picture()
        cover.type = 3  # front cover, same picture type as the MP3 APIC frame
        cover.mime = "image/jpeg"
        cover.desc = "Cover"
        cover.data = image_data
        audio.add_picture(cover)
|
||||
|
||||
|
||||
def _set_mp4_tags(audio, info: Metadata) -> None:
    """Write the values from ``info`` into an MP4 file's tags.

    The file is not saved here.

    NOTE(review): the keys written here ("nam", "ART", ...) differ from the
    "\\xa9nam"-style keys that ``extract_audio_metadata`` reads, and the
    caller opens the file with ``easy=True`` - confirm that these keys are
    accepted and round-trip.
    """
    text_fields = (
        ("nam", info.title),
        ("ART", info.artist),
        ("alb", info.album),
        ("day", info.year),
        ("gen", info.genre),
    )
    for key, value in text_fields:
        audio[key] = value

    if info.picture:
        with open(info.picture, "rb") as cover_file:
            cover_bytes = cover_file.read()
        audio["covr"] = [cover_bytes]
|
||||
|
||||
|
||||
def _set_ogg_tags(audio, info: Metadata) -> None:
    """Write the values from ``info`` into an Ogg Vorbis file's comments.

    The cover is stored in ``metadata_block_picture`` as a base64-encoded
    FLAC picture block, not as base64 of the raw image bytes. The file is
    not saved here.
    """
    audio["TITLE"] = info.title
    audio["ARTIST"] = info.artist
    audio["ALBUM"] = info.album
    audio["DATE"] = info.year
    audio["GENRE"] = info.genre
    if info.lyrics:
        audio["LYRICS"] = info.lyrics
    if info.picture:
        # Local import keeps this block independent of the module import list.
        from mutagen.flac import Picture

        with open(info.picture, "rb") as img_file:
            image_data = img_file.read()
        cover = Picture()
        cover.type = 3  # front cover
        cover.mime = "image/jpeg"
        cover.desc = "Cover"
        cover.data = image_data
        audio["metadata_block_picture"] = [
            base64.b64encode(cover.write()).decode("ascii")
        ]
|
||||
|
||||
|
||||
def _set_asf_tags(audio, info: Metadata) -> None:
    """Write the values from ``info`` into an ASF file's attributes.

    The cover image file's bytes are stored under ``WM/Picture``. The file
    is not saved here.
    """
    attributes = (
        ("Title", info.title),
        ("Author", info.artist),
        ("WM/AlbumTitle", info.album),
        ("WM/Year", info.year),
        ("WM/Genre", info.genre),
    )
    for key, value in attributes:
        audio[key] = value

    if info.picture:
        with open(info.picture, "rb") as cover_file:
            cover_bytes = cover_file.read()
        audio["WM/Picture"] = cover_bytes
|
||||
|
||||
|
||||
def _set_wave_tags(audio, info: Metadata) -> None:
    """Write title and artist from ``info`` into a WAVE file's tags.

    Only these two fields are written. The file is not saved here.
    """
    for key, value in (("Title", info.title), ("Artist", info.artist)):
        audio[key] = value
|
||||
|
||||
|
||||
def set_music_tag_to_file(file_path: str, info: Metadata) -> str:
    """Write the tag values from ``info`` into an audio file and save it.

    The file is opened with ``mutagen.File(..., easy=True)`` and handed to
    the setter of the first matching container type (MP3, FLAC, MP4,
    Ogg Vorbis, ASF, WAVE - checked in that order).

    Args:
        file_path: Path of the audio file.
        info: Metadata to write.

    Returns:
        ``"OK"`` on success, otherwise one of ``"Unable to open file"``,
        ``"Unsupported file type"`` or ``"Error saving tags"``.
    """
    audio = mutagen.File(file_path, easy=True)
    if audio is None:
        log.error(f"Unable to open file {file_path}")
        return "Unable to open file"

    # Order matters: the first isinstance match wins.
    setters = (
        (MP3, _set_mp3_tags),
        (FLAC, _set_flac_tags),
        (MP4, _set_mp4_tags),
        (OggVorbis, _set_ogg_tags),
        (ASF, _set_asf_tags),
        (WAVE, _set_wave_tags),
    )
    for audio_cls, apply_tags in setters:
        if isinstance(audio, audio_cls):
            apply_tags(audio, info)
            break
    else:
        log.error(f"Unsupported file type for {file_path}")
        return "Unsupported file type"

    try:
        audio.save()
        log.info(f"Tags saved successfully to {file_path}")
        return "OK"
    except Exception as e:
        log.exception(f"Error saving tags: {e}")
        return "Error saving tags"
|
||||
448
xiaomusic/utils/network_utils.py
Normal file
448
xiaomusic/utils/network_utils.py
Normal file
@@ -0,0 +1,448 @@
|
||||
#!/usr/bin/env python3
|
||||
"""网络请求和下载相关工具函数"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import aiohttp
|
||||
import edge_tts
|
||||
|
||||
log = logging.getLogger(__package__)
|
||||
|
||||
|
||||
async def downloadfile(url: str) -> str:
    """Download a URL and return the response body as text.

    Args:
        url: Address to fetch; only ``http`` and ``https`` are accepted.

    Returns:
        The decoded response text.

    Raises:
        Warning: If the URL scheme is neither ``http`` nor ``https``.
        aiohttp.ClientResponseError: If the response status signals an
            error (raised by ``raise_for_status``).
    """
    parsed_url = urlparse(url)
    # Only plain web URLs are allowed.
    if parsed_url.scheme not in ("http", "https"):
        raise Warning(
            f"Invalid URL scheme: {parsed_url.scheme}. Only HTTP and HTTPS are allowed."
        )
    cleaned_url = parsed_url.geturl()

    # Explicit ClientTimeout instead of a bare number: 5 seconds in total,
    # so a stalled server cannot hang the caller.
    request_timeout = aiohttp.ClientTimeout(total=5)
    async with aiohttp.ClientSession() as session:
        async with session.get(cleaned_url, timeout=request_timeout) as response:
            response.raise_for_status()
            text = await response.text()
            return text
|
||||
|
||||
|
||||
async def check_bili_fav_list(url: str) -> dict:
    """List the videos of a Bilibili favourites folder or collection.

    Args:
        url: A ``space.bilibili.com`` favourites (``/favlist``) or collection
            (``/lists/``) URL.

    Returns:
        A mapping of video to title. Collections are keyed by ``bvid``;
        user-created favourites are keyed by the full video URL. URLs on any
        other host yield an empty mapping.

    Raises:
        ValueError: If the URL does not describe a supported list type.
        Exception: If the Bilibili API answers with a non-200 status.
    """
    bvid_info = {}
    parsed_url = urlparse(url)
    if parsed_url.hostname != "space.bilibili.com":
        return bvid_info

    path = parsed_url.path
    query_params = parse_qs(parsed_url.query)

    # Initialise both values so an unrecognised path produces a clean
    # ValueError below instead of an UnboundLocalError.
    lid = None
    kind = None
    if "/favlist" in path:
        lid = query_params.get("fid", [None])[0]
        ctype = query_params.get("ctype", [None])[0]
        # ctype 11 -> user-created favourites, 21 -> subscribed collection.
        kind = {"11": "create", "21": "collect"}.get(ctype)
    elif "/lists/" in path:
        parts = path.split("/")
        if len(parts) >= 4 and "?" in url:
            lid = parts[3]
            kind = query_params.get("type", [None])[0]

    if kind not in ("season", "collect", "create"):
        raise ValueError("当前只支持合集和收藏夹")

    page_size = 100
    page_num = 1
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0 Safari/537.36",
        "Accept": "application/json, text/plain, */*",
        "Referer": url,
        "Origin": "https://space.bilibili.com",
    }
    async with aiohttp.ClientSession(headers=headers) as session:
        if kind in ("season", "collect"):
            while True:
                list_url = f"https://api.bilibili.com/x/polymer/web-space/seasons_archives_list?season_id={lid}&page_size={page_size}&page_num={page_num}"
                async with session.get(list_url) as response:
                    if response.status != 200:
                        raise Exception(f"Failed to fetch data from {list_url}")
                    data = await response.json()
                # ``data`` may be null when the API reports an error code.
                archives = (data.get("data") or {}).get("archives") or []
                if not archives:
                    break
                for archive in archives:
                    bvid_info[archive.get("bvid", None)] = archive.get("title", None)

                if len(archives) < page_size:
                    break
                page_num += 1
                # Throttle between pages without blocking the event loop.
                await asyncio.sleep(1)
        else:  # kind == "create"
            while True:
                list_url = f"https://api.bilibili.com/x/v3/fav/resource/list?media_id={lid}&pn={page_num}&ps={page_size}&order=mtime"
                async with session.get(list_url) as response:
                    if response.status != 200:
                        raise Exception(f"Failed to fetch data from {list_url}")
                    data = await response.json()
                medias = (data.get("data") or {}).get("medias") or []
                if not medias:
                    break
                for media in medias:
                    bvid = media.get("bvid", None)
                    bvurl = f"https://www.bilibili.com/video/{bvid}"
                    bvid_info[bvurl] = media.get("title", None)

                if len(medias) < page_size:
                    break
                page_num += 1
    return bvid_info
|
||||
|
||||
|
||||
async def download_playlist(config, url: str, dirname: str):
    """Start a ``yt-dlp`` process that downloads a whole playlist as MP3.

    Args:
        config: Configuration object supplying ``download_path``,
            ``ffmpeg_location``, ``proxy``, ``enable_yt_dlp_cookies``,
            ``yt_dlp_cookies_path`` and ``loudnorm``.
        url: Playlist URL.
        dirname: Sub-directory (inside the download path) for the tracks.

    Returns:
        The started subprocess object; it is not awaited here.
    """
    output_template = f"{dirname}/%(title)s.%(ext)s"
    argv = [
        "yt-dlp",
        "--yes-playlist",
        "-x",
        "--audio-format",
        "mp3",
        "--audio-quality",
        "0",
        "--paths",
        config.download_path,
        "-o",
        output_template,
        "--ffmpeg-location",
        f"{config.ffmpeg_location}",
    ]

    # Optional switches, appended in a fixed order before the URL.
    if config.proxy:
        argv.extend(["--proxy", f"{config.proxy}"])
    if config.enable_yt_dlp_cookies:
        argv.extend(["--cookies", f"{config.yt_dlp_cookies_path}"])
    if config.loudnorm:
        argv.extend(["--postprocessor-args", f"-af {config.loudnorm}"])
    argv.append(url)

    cmd = " ".join(argv)
    log.info(f"download_playlist: {cmd}")
    return await asyncio.create_subprocess_exec(*argv)
|
||||
|
||||
|
||||
async def download_one_music(config, url: str, name: str = ""):
    """Start a ``yt-dlp`` process that downloads a single track as MP3.

    Args:
        config: Configuration object supplying ``download_path``,
            ``ffmpeg_location``, ``proxy``, ``enable_yt_dlp_cookies``,
            ``yt_dlp_cookies_path`` and ``loudnorm``.
        url: Track URL.
        name: Optional file name (without extension); when empty the
            title reported by yt-dlp is used.

    Returns:
        The started subprocess object; it is not awaited here.
    """
    output_template = f"{name}.%(ext)s" if name else "%(title)s.%(ext)s"
    argv = [
        "yt-dlp",
        "--no-playlist",
        "-x",
        "--audio-format",
        "mp3",
        "--audio-quality",
        "0",
        "--paths",
        config.download_path,
        "-o",
        output_template,
        "--ffmpeg-location",
        f"{config.ffmpeg_location}",
    ]

    # Optional switches, appended in a fixed order before the URL.
    if config.proxy:
        argv.extend(["--proxy", f"{config.proxy}"])
    if config.enable_yt_dlp_cookies:
        argv.extend(["--cookies", f"{config.yt_dlp_cookies_path}"])
    if config.loudnorm:
        argv.extend(["--postprocessor-args", f"-af {config.loudnorm}"])
    argv.append(url)

    cmd = " ".join(argv)
    log.info(f"download_one_music: {cmd}")
    return await asyncio.create_subprocess_exec(*argv)
|
||||
|
||||
|
||||
async def fetch_json_get(url: str, headers: dict, config) -> dict:
    """Issue a GET request and return the JSON body as a dictionary.

    Args:
        url: Request URL.
        headers: Request headers.
        config: Configuration object; when it has a truthy ``proxy`` the
            request is routed through that proxy.

    Returns:
        The parsed JSON object, or an empty dict on any failure (non-200
        status, non-dict payload, client error, timeout or other exception).
    """
    proxy = None
    connector = None
    if config and config.proxy:
        proxy = config.proxy
        # Certificate verification is switched off on the proxied connector.
        connector = aiohttp.TCPConnector(ssl=False, limit=10)

    try:
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.get(
                url,
                headers=headers,
                proxy=proxy,
                timeout=10,  # seconds; avoids waiting forever
            ) as response:
                if response.status != 200:
                    log.error(f"HTTP Error: {response.status} {url}")
                    return {}

                data = await response.json()
                log.info(f"fetch_json_get: {url} success {data}")

                # Callers expect a mapping; anything else is treated as a miss.
                if not isinstance(data, dict):
                    log.warning(f"Expected dict, but got {type(data)}: {data}")
                    return {}
                return data
    except aiohttp.ClientError as e:
        log.error(f"ClientError fetching {url} (proxy: {proxy}): {e}")
        return {}
    except asyncio.TimeoutError:
        log.error(f"Timeout fetching {url} (proxy: {proxy})")
        return {}
    except Exception as e:
        log.error(f"Unexpected error fetching {url} (proxy: {proxy}): {e}")
        return {}
    finally:
        # Close the connector if it is still open, to avoid leaking it.
        if connector and not connector.closed:
            await connector.close()
|
||||
|
||||
|
||||
class LRUCache(OrderedDict):
|
||||
"""LRU 缓存实现"""
|
||||
|
||||
def __init__(self, max_size: int = 1000):
|
||||
super().__init__()
|
||||
self.max_size = max_size
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key in self:
|
||||
# 移动到末尾(最近使用)
|
||||
self.move_to_end(key)
|
||||
super().__setitem__(key, value)
|
||||
# 如果超出大小限制,删除最早使用的项
|
||||
if len(self) > self.max_size:
|
||||
self.popitem(last=False)
|
||||
|
||||
def __getitem__(self, key):
|
||||
# 访问时移动到末尾(最近使用)
|
||||
if key in self:
|
||||
self.move_to_end(key)
|
||||
return super().__getitem__(key)
|
||||
|
||||
|
||||
class MusicUrlCache:
|
||||
"""音乐 URL 缓存管理器"""
|
||||
|
||||
def __init__(self, default_expire_days: int = 1, max_size: int = 1000):
|
||||
self.cache = LRUCache(max_size)
|
||||
self.default_expire_days = default_expire_days
|
||||
self.log = logging.getLogger(__name__)
|
||||
|
||||
async def get(self, url: str, headers: dict = None, config=None) -> str:
|
||||
"""
|
||||
获取URL(优先从缓存获取,没有则请求API)
|
||||
|
||||
Args:
|
||||
url: 原始URL
|
||||
headers: API请求需要的headers
|
||||
config: 配置对象
|
||||
|
||||
Returns:
|
||||
str: 真实播放URL
|
||||
"""
|
||||
# 先查询缓存
|
||||
cached_url = self._get_from_cache(url)
|
||||
if cached_url:
|
||||
self.log.info(f"Using cached url: {cached_url}")
|
||||
return cached_url
|
||||
|
||||
# 缓存未命中,请求API
|
||||
return await self._fetch_from_api(url, headers, config)
|
||||
|
||||
def _get_from_cache(self, url: str) -> str:
|
||||
"""从缓存中获取URL"""
|
||||
try:
|
||||
cached_url, expire_time = self.cache[url]
|
||||
if time.time() > expire_time:
|
||||
# 缓存过期,删除
|
||||
del self.cache[url]
|
||||
return ""
|
||||
return cached_url
|
||||
except KeyError:
|
||||
return ""
|
||||
|
||||
async def _fetch_from_api(self, url: str, headers: dict = None, config=None) -> str:
|
||||
"""从API获取真实URL"""
|
||||
data = await fetch_json_get(url, headers or {}, config)
|
||||
|
||||
if not isinstance(data, dict):
|
||||
self.log.error(f"Invalid API response format: {data}")
|
||||
return ""
|
||||
|
||||
real_url = data.get("url")
|
||||
if not real_url:
|
||||
self.log.error(f"No url in API response: {data}")
|
||||
return ""
|
||||
|
||||
# 获取过期时间
|
||||
expire_time = self._parse_expire_time(data)
|
||||
|
||||
# 缓存结果
|
||||
self._set_cache(url, real_url, expire_time)
|
||||
self.log.info(
|
||||
f"Cached url, expire_time: {expire_time}, cache size: {len(self.cache)}"
|
||||
)
|
||||
return real_url
|
||||
|
||||
def _parse_expire_time(self, data: dict) -> float | None:
|
||||
"""解析API返回的过期时间"""
|
||||
try:
|
||||
extra = data.get("extra", {})
|
||||
expire_info = extra.get("expire", {})
|
||||
if expire_info and expire_info.get("canExpire"):
|
||||
expire_time = expire_info.get("time")
|
||||
if expire_time:
|
||||
return float(expire_time)
|
||||
except Exception as e:
|
||||
self.log.warning(f"Failed to parse expire time: {e}")
|
||||
return None
|
||||
|
||||
def _set_cache(self, url: str, real_url: str, expire_time: float = None):
|
||||
"""设置缓存"""
|
||||
if expire_time is None:
|
||||
expire_time = time.time() + (self.default_expire_days * 24 * 3600)
|
||||
self.cache[url] = (real_url, expire_time)
|
||||
|
||||
def clear(self):
|
||||
"""清空缓存"""
|
||||
self.cache.clear()
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""当前缓存大小"""
|
||||
return len(self.cache)
|
||||
|
||||
|
||||
async def text_to_mp3(
    text: str, save_dir: str, voice: str = "zh-CN-XiaoxiaoNeural"
) -> str:
    """Convert text to an MP3 speech file with edge-tts.

    The file name is the MD5 of the text and voice, so repeated requests for
    the same text and voice reuse the file already on disk.

    Args:
        text: Text to synthesise.
        save_dir: Directory in which the MP3 is stored (created if missing).
        voice: edge-tts voice name.

    Returns:
        Full path of the MP3 file.

    Raises:
        RuntimeError: If speech synthesis fails.
    """
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Text and voice both feed the hash so the same text in a different voice
    # gets its own file.
    content = f"{text}_{voice}".encode()
    file_hash = hashlib.md5(content).hexdigest()
    mp3_path = os.path.join(save_dir, f"{file_hash}.mp3")

    if os.path.exists(mp3_path):
        return mp3_path

    # Write to a temporary name first: a failed or interrupted synthesis must
    # not leave a truncated file at the cache path, or every later call would
    # return it as a valid hit.
    tmp_path = f"{mp3_path}.part"
    try:
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(tmp_path)
        os.replace(tmp_path, mp3_path)
        log.info(f"语音文件生成成功: {mp3_path}")
    except Exception as e:
        try:
            os.remove(tmp_path)
        except OSError:
            # Best-effort cleanup; the temp file may never have been created.
            pass
        raise RuntimeError(f"生成语音文件失败: {e}") from e

    return mp3_path
|
||||
299
xiaomusic/utils/system_utils.py
Normal file
299
xiaomusic/utils/system_utils.py
Normal file
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python3
|
||||
"""系统操作和环境相关工具函数"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import string
|
||||
import urllib.parse
|
||||
from http.cookies import SimpleCookie
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from requests.utils import cookiejar_from_dict
|
||||
|
||||
log = logging.getLogger(__package__)
|
||||
|
||||
|
||||
def parse_cookie_string(cookie_string: str):
    """Parse a raw cookie string into a cookie jar.

    Args:
        cookie_string: Cookie string, as loaded by ``SimpleCookie.load``.

    Returns:
        The cookie jar built by ``cookiejar_from_dict`` from the parsed
        name/value pairs.
    """
    parsed = SimpleCookie()
    parsed.load(cookie_string)
    plain_pairs = {}
    for name, morsel in parsed.items():
        plain_pairs[name] = morsel.value
    return cookiejar_from_dict(plain_pairs, cookiejar=None, overwrite=True)
|
||||
|
||||
|
||||
def validate_proxy(proxy_str: str) -> bool:
    """Check that a proxy string is an HTTP(S) URL with host and port.

    Args:
        proxy_str: Proxy address, e.g. ``http://127.0.0.1:7890``.

    Returns:
        True when the string is well formed.

    Raises:
        ValueError: If the scheme is not http/https, or the host or port
            is missing.
    """
    pieces = urlparse(proxy_str)
    scheme_ok = pieces.scheme in ("http", "https")
    if not scheme_ok:
        raise ValueError("Proxy scheme must be http or https")
    if not pieces.hostname or not pieces.port:
        raise ValueError("Proxy hostname and port must be set")
    return True
|
||||
|
||||
|
||||
def get_random(length: int) -> str:
    """Generate a random string of ASCII letters and digits.

    Args:
        length: Number of characters to produce.

    Returns:
        A random string of exactly ``length`` characters.

    Raises:
        ValueError: If ``length`` is negative.
    """
    if length < 0:
        raise ValueError("length must be non-negative")
    alphabet = string.ascii_letters + string.digits
    # Draw with replacement: sampling without replacement capped the length
    # at the 62 characters of the alphabet and never repeated a character.
    return "".join(random.choices(alphabet, k=length))
|
||||
|
||||
|
||||
def deepcopy_data_no_sensitive_info(data, fields_to_anonymize: list = None):
    """Deep-copy ``data`` and mask its sensitive top-level fields.

    Args:
        data: A dictionary or an object with attributes.
        fields_to_anonymize: Names of the fields to mask. Defaults to the
            account, password and HTTP-auth credential fields.

    Returns:
        A deep copy in which every listed field that exists is replaced by
        ``"******"``. The input is left untouched.
    """
    masked_fields = (
        fields_to_anonymize
        if fields_to_anonymize is not None
        else ["account", "password", "httpauth_username", "httpauth_password"]
    )
    clone = copy.deepcopy(data)

    if isinstance(clone, dict):
        # Mapping: mask by key.
        for name in masked_fields:
            if name in clone:
                clone[name] = "******"
        return clone

    # Any other object: mask by attribute.
    for name in masked_fields:
        if hasattr(clone, name):
            setattr(clone, name, "******")
    return clone
|
||||
|
||||
|
||||
def try_add_access_control_param(config, url: str) -> str:
    """Append the access-control ``code`` query parameter to a URL.

    The code is the SHA-256 hex digest of the unquoted URL path followed by
    the configured HTTP-auth user name and password.

    Args:
        config: Configuration object (``disable_httpauth``,
            ``httpauth_username``, ``httpauth_password``).
        url: Original URL.

    Returns:
        The URL with ``code`` added, or the URL unchanged when HTTP auth
        is disabled.
    """
    if config.disable_httpauth:
        return url

    pieces = urllib.parse.urlparse(url)
    file_path = urllib.parse.unquote(pieces.path)
    signed_text = file_path + config.httpauth_username + config.httpauth_password
    correct_code = hashlib.sha256(signed_text.encode("utf-8")).hexdigest()
    log.debug(f"rewrite url: [{file_path}, {correct_code}]")

    # Rebuild the query string with the code added (or overwritten).
    query_args = dict(urllib.parse.parse_qsl(pieces.query))
    query_args["code"] = correct_code
    encoded_query = urllib.parse.urlencode(query_args, doseq=True)
    return pieces._replace(query=encoded_query).geturl()
|
||||
|
||||
|
||||
def is_docker() -> bool:
    """Return True when the ``/app/.dockerenv`` marker file exists."""
    marker = "/app/.dockerenv"
    return os.path.exists(marker)
|
||||
|
||||
|
||||
def get_os_architecture() -> str:
    """Map the machine type to a release architecture name.

    Returns:
        ``"amd64"``, ``"arm64"`` or ``"arm-v7"``; for anything else a string
        of the form ``"unknown architecture: <machine>"``.
    """
    machine = platform.machine().lower()
    exact_names = {
        "x86_64": "amd64",
        "amd64": "amd64",
        "aarch64": "arm64",
        "arm64": "arm64",
    }
    if machine in exact_names:
        return exact_names[machine]
    # Every remaining ARM variant (armv7l, armv6l, ...) maps to arm-v7.
    if "arm" in machine:
        return "arm-v7"
    return f"unknown architecture: {machine}"
|
||||
|
||||
|
||||
async def get_latest_version(package_name: str) -> str:
    """Look up the latest released version of a package on PyPI.

    Args:
        package_name: Name of the package.

    Returns:
        The version string from the PyPI JSON API, or None when the request
        does not return status 200.
    """
    url = f"https://pypi.org/pypi/{package_name}/json"
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status != 200:
                return None
            payload = await response.json()
            return payload["info"]["version"]
|
||||
|
||||
|
||||
async def restart_xiaomusic() -> int:
    """Restart the xiaomusic program through ``supervisorctl``.

    Waits two seconds, runs ``supervisorctl restart xiaomusic`` and waits for
    that command to finish.

    Returns:
        Exit code of the ``supervisorctl`` process.
    """
    argv = ["supervisorctl", "restart", "xiaomusic"]
    cmd = " ".join(argv)
    log.info(f"restart_xiaomusic: {cmd}")

    await asyncio.sleep(2)
    process = await asyncio.create_subprocess_exec(*argv)
    # Wait for the child process to finish.
    exit_code = await process.wait()
    log.info(f"restart_xiaomusic completed with exit code {exit_code}")
    return exit_code
|
||||
|
||||
|
||||
async def update_version(version: str, lite: bool = True) -> str:
    """Download and unpack a xiaomusic release into ``/app``.

    Args:
        version: Release tag to download.
        lite: Whether to fetch the ``-lite`` package variant.

    Returns:
        A result message: the outcome of the download step, the
        "docker only" notice, or the unknown-architecture string.
    """
    if not is_docker():
        ret = "xiaomusic 更新只能在 docker 中进行"
        log.info(ret)
        return ret

    lite_tag = "-lite" if lite else ""
    arch = get_os_architecture()
    if "unknown" in arch:
        log.warning(f"update_version failed: {arch}")
        return arch

    # Mirrors https://github.com/hanxi/xiaomusic/releases/download/main/app-amd64-lite.tar.gz
    url = f"https://gproxy.hanxi.cc/proxy/hanxi/xiaomusic/releases/download/{version}/app-{arch}{lite_tag}.tar.gz"
    return await download_and_extract(url, "/app")
|
||||
|
||||
|
||||
async def download_and_extract(url: str, target_directory: str) -> str:
    """Download a file into a directory and unpack it if it is a tar.gz.

    Args:
        url: Download URL; its last path segment becomes the file name.
        target_directory: Directory that receives the file (created if missing).

    Returns:
        ``"OK"`` on success, otherwise a message describing the failure.
    """
    os.makedirs(target_directory, exist_ok=True)

    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status != 200:
                ret = f"下载失败, 状态码: {response.status}"
                log.warning(ret)
                return ret

            file_name = os.path.normpath(
                os.path.join(target_directory, url.split("/")[-1])
            )
            # The file must live strictly inside the target directory. Compare
            # absolute paths component-wise: a plain string-prefix test wrongly
            # rejects relative targets such as "./out", accepts sibling
            # directories sharing a prefix, and lets an empty file name resolve
            # to the directory itself.
            base_dir = os.path.abspath(target_directory)
            abs_file = os.path.abspath(file_name)
            if (
                abs_file == base_dir
                or os.path.commonpath([base_dir, abs_file]) != base_dir
            ):
                log.warning(f"Invalid file path: {file_name}")
                return "Invalid file path"

            with open(file_name, "wb") as f:
                # Stream the body in chunks to keep memory usage low.
                async for chunk in response.content.iter_any():
                    f.write(chunk)
            log.info(f"文件下载完成: {file_name}")

            if not file_name.endswith(".tar.gz"):
                ret = f"下载失败, 包有问题: {file_name}"
                log.warning(ret)
                return ret

            await extract_tar_gz(file_name, target_directory)
            return "OK"
|
||||
|
||||
|
||||
async def extract_tar_gz(file_name: str, target_directory: str) -> None:
    """Launch ``tar`` to unpack a tar.gz archive.

    The ``tar`` process is started but not awaited, so extraction may still
    be running when this coroutine returns.

    Args:
        file_name: Path of the archive.
        target_directory: Directory to extract into.
    """
    argv = ("tar", "-xzvf", file_name, "-C", target_directory)
    # Start the child process and return without waiting for it.
    await asyncio.create_subprocess_exec(*argv)
    log.info(f"extract_tar_gz ing {file_name}")
|
||||
264
xiaomusic/utils/text_utils.py
Normal file
264
xiaomusic/utils/text_utils.py
Normal file
@@ -0,0 +1,264 @@
|
||||
#!/usr/bin/env python3
|
||||
"""文本处理和搜索相关工具函数"""
|
||||
|
||||
import difflib
|
||||
import re
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from opencc import OpenCC
|
||||
|
||||
# Traditional-to-Simplified Chinese converter ("t2s" OpenCC profile).
cc = OpenCC("t2s")

# TTS-related patterns.
# Characters removed before estimating speech duration: quotes, brackets and
# single (non-doubled) hyphens.
# NOTE(review): the pattern is written as two adjacent string literals and
# lists "()" twice — presumably curly quotes and full-width parentheses were
# normalised to ASCII somewhere; confirm against the upstream file.
_no_elapse_chars = re.compile(r"([「」『』《》" "'\"()()]|(?<!-)-(?!-))", re.UNICODE)
# Punctuation that terminates a sentence when splitting streamed text.
# NOTE(review): "?", "!" and ";" appear twice — presumably one set was the
# full-width forms; confirm against the upstream file.
_ending_punctuations = ("。", "?", "!", ";", ".", "?", "!", ";")

# Chinese numeral characters and the values they stand for (digits and units).
chinese_to_arabic = {
    "零": 0,
    "一": 1,
    "二": 2,
    "三": 3,
    "四": 4,
    "五": 5,
    "六": 6,
    "七": 7,
    "八": 8,
    "九": 9,
    "十": 10,
    "百": 100,
    "千": 1000,
    "万": 10000,
    "亿": 100000000,
}
|
||||
|
||||
|
||||
def calculate_tts_elapse(text: str) -> float:
    """Estimate how long the TTS engine takes to speak ``text``.

    Characters matched by ``_no_elapse_chars`` are not counted; the remaining
    character count is divided by a fixed speaking speed.
    """
    # Fixed speed for simplicity; the value was picked by trial and error.
    speed = 4.5
    spoken_text = _no_elapse_chars.sub("", text)
    return len(spoken_text) / speed
|
||||
|
||||
|
||||
async def split_sentences(text_stream: AsyncIterator[str]) -> AsyncIterator[str]:
    """Regroup a stream of text fragments into sentences.

    Fragments are accumulated until the buffer ends with one of the
    ``_ending_punctuations``; any remainder is yielded once the stream ends.
    """
    buffer = ""
    async for fragment in text_stream:
        buffer = buffer + fragment
        if not buffer.endswith(_ending_punctuations):
            continue
        yield buffer
        buffer = ""
    if buffer:
        yield buffer
|
||||
|
||||
|
||||
def find_key_by_partial_string(dictionary: dict[str, str], partial_key: str) -> str:
    """Return the value of the first key that occurs inside ``partial_key``.

    Keys are tried in the dictionary's iteration order; None is returned when
    no key is a substring of ``partial_key``.
    """
    return next(
        (value for key, value in dictionary.items() if key in partial_key),
        None,
    )
|
||||
|
||||
|
||||
def traditional_to_simple(to_convert: str) -> str:
    """Convert Traditional Chinese text to Simplified Chinese."""
    simplified = cc.convert(to_convert)
    return simplified
|
||||
|
||||
|
||||
def keyword_detection(user_input: str, str_list: list, n: int) -> tuple[list, list]:
    """Split candidates into those containing the keyword and the rest.

    Args:
        user_input: Keyword to look for.
        str_list: Candidate strings.
        n: Maximum number of matches to return; -1 returns all of them.

    Returns:
        ``(matches, remains)``. Matches are ordered from most to least similar
        to the keyword; matches beyond the first ``n`` are moved to the front
        of ``remains``.
    """
    matched, remains = [], []
    for candidate in str_list:
        bucket = matched if user_input in candidate else remains
        bucket.append(candidate)

    def similarity(candidate):
        return difflib.SequenceMatcher(None, candidate, user_input).ratio()

    # Descending: the closer a candidate is to the keyword, the earlier it comes.
    matched.sort(key=similarity, reverse=True)

    # -1, or a limit larger than the number of matches, means "return them all".
    if n == -1 or n > len(matched):
        return matched, remains

    return matched[:n], matched[n:] + remains
|
||||
|
||||
|
||||
def real_search(prompt: str, candidates: list, cutoff: float, n: int) -> list:
    """Search candidates by keyword first, then by fuzzy similarity.

    When the keyword pass yields fewer than ``n`` results, close matches from
    the remaining candidates are appended.
    """
    matches, remains = keyword_detection(prompt, candidates, n=n)
    if len(matches) >= n:
        return matches
    # Not enough exact keyword hits: fall back to fuzzy matching.
    fuzzy_hits = difflib.get_close_matches(prompt, remains, n=n, cutoff=cutoff)
    matches.extend(fuzzy_hits)
    return matches
|
||||
|
||||
|
||||
def find_best_match(
    user_input: str,
    collection: list,
    cutoff: float = 0.6,
    n: int = 1,
    extra_search_index: dict = None,
) -> list:
    """Find the entries of ``collection`` that best match ``user_input``.

    Comparison is done on lower-cased, Simplified Chinese forms of the text.

    Args:
        user_input: Search text.
        collection: Candidate entries.
        cutoff: Similarity threshold for fuzzy matching.
        n: Maximum number of results.
        extra_search_index: Optional mapping of alternative search keys to
            entries, consulted when the collection yields fewer than ``n`` hits.

    Returns:
        Up to ``n`` matching entries.
    """
    normalized = {}
    for item in collection:
        normalized[traditional_to_simple(item.lower())] = item
    query = traditional_to_simple(user_input.lower())

    hits = real_search(query, list(normalized.keys()), cutoff, n)
    results = [normalized[hit] for hit in hits]
    if extra_search_index is None or len(hits) >= n:
        return results[:n]

    # Not enough results yet: search the extra index, skipping entries
    # that were already found.
    normalized_extra = {}
    for key, value in extra_search_index.items():
        if value not in results:
            normalized_extra[traditional_to_simple(key.lower())] = value
    extra_hits = real_search(query, list(normalized_extra.keys()), cutoff, n)
    results += [normalized_extra[hit] for hit in extra_hits]
    return results[:n]
|
||||
|
||||
|
||||
def fuzzyfinder(
    user_input: str, collection: list, extra_search_index: dict = None
) -> list:
    """Loose search: up to 10 results with a similarity cutoff of 0.1."""
    search_options = {
        "cutoff": 0.1,
        "n": 10,
        "extra_search_index": extra_search_index,
    }
    return find_best_match(user_input, collection, **search_options)
|
||||
|
||||
|
||||
def custom_sort_key(s: str) -> tuple:
    """Sort key for song names.

    Names with a numeric prefix come first (by number, then by full name),
    then names with a numeric suffix (by the text before it, then by number),
    then everything else in lexical order.
    """
    head = re.match(r"^(\d+)", s)
    tail = re.search(r"(\d+)$", s)
    head_number = None if head is None else int(head[0])
    tail_number = None if tail is None else int(tail[0])

    if head_number is not None:
        return (0, head_number, s)
    if tail_number is not None:
        return (1, s[: tail.start()], tail_number)
    return (2, s)
|
||||
|
||||
|
||||
def chinese_to_number(chinese: str) -> int:
    """Convert a Chinese numeral string to an integer.

    Characters that are not in ``chinese_to_arabic`` are ignored.

    Args:
        chinese: Chinese numeral, e.g. "一百二十三".

    Returns:
        The corresponding integer (0 when no numeral character is present).
    """
    # A leading "十" means "一十" (e.g. "十五" is 15).
    if chinese.startswith("十"):
        chinese = "一" + chinese

    # A lone unit character stands for its own value. Use .get() so a lone
    # unknown character is ignored like any other, instead of raising KeyError.
    if len(chinese) == 1 and chinese_to_arabic.get(chinese, 0) >= 10:
        return chinese_to_arabic[chinese]

    unit = 1
    total = 0
    # Walk from the least significant character, tracking the current unit.
    for char in reversed(chinese):
        val = chinese_to_arabic.get(char)
        if val is None:
            continue
        if val >= 10:
            # A larger unit replaces the current one; a smaller one compounds
            # it (e.g. "十" before "万" gives 100000).
            unit = val if val > unit else unit * val
        else:
            total += val * unit
    return total
|
||||
|
||||
|
||||
def parse_str_to_dict(s: str, d1: str = ",", d2: str = ":") -> dict:
    """Parse a delimited string such as ``k1:v1,k2:v2`` into a dictionary.

    Args:
        s: String to parse.
        d1: Separator between pairs (comma by default).
        d2: Separator between key and value (colon by default).

    Returns:
        The parsed dictionary. Segments that do not split into exactly one
        key and one value are skipped.
    """
    split_pairs = (segment.split(d2) for segment in s.split(d1))
    # Keep only well-formed pairs; a later duplicate key overrides an earlier one.
    return {pair[0]: pair[1] for pair in split_pairs if len(pair) == 2}
|
||||
|
||||
|
||||
def list2str(li: list, verbose: bool = False) -> str:
    """Render a list for display, abbreviating long lists.

    Args:
        li: The list to render.
        verbose: When True the full list is always shown.

    Returns:
        The full representation, or for lists longer than five items the
        first two and last two items plus the length.
    """
    if verbose or len(li) <= 5:
        return f"{li}"
    return f"{li[:2]} ... {li[-2:]} with len: {len(li)}"
|
||||
Reference in New Issue
Block a user