Files
test_nine/train_paddle.py
2024-11-02 02:22:29 +08:00

46 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from paddlex.utils.result_saver import try_except_decorator
from paddlex.utils.config import parse_args, get_config
from paddlex.utils.errors import raise_unsupported_api_error
from paddlex.model import _ModelBasedConfig
print(f"""数据集格式如下:
dataset
├─images #所有图片存放路径
├─label.txt #标签路径,每一行数据格式为 <序号>+<空格>+<类别>如15 地球仪
├─train.txt #训练图片,每一行数据格式为 <图片路径>+<空格>+<类别>如images/001.jpg 0
└─验证集和测试集同上
""")
class Engine(object):
"""Engine"""
def __init__(self):
args = parse_args()
args.config='PP-HGNetV2-B4.yaml'
args.override=['Global.mode=train', 'Global.dataset_dir=dataset']
config = get_config(args.config, overrides=args.override, show=False)
self._mode = config.Global.mode
self._output = config.Global.output
self._model = _ModelBasedConfig(config)
@try_except_decorator
def run(self):
"""the main function"""
if self._mode == "check_dataset":
return self._model.check_dataset()
elif self._mode == "train":
self._model.train()
elif self._mode == "evaluate":
return self._model.evaluate()
elif self._mode == "export":
return self._model.export()
elif self._mode == "predict":
for res in self._model.predict():
res.print(json_format=False)
else:
raise_unsupported_api_error(f"{self._mode}", self.__class__)
if __name__ == "__main__":
Engine().run()