This commit is contained in:
yadonglu
2024-12-05 11:28:18 -08:00
parent 2f13aebc6e
commit 075f349ea1
5 changed files with 100 additions and 161 deletions

View File

@@ -3,25 +3,24 @@ from ultralytics.nn.tasks import DetectionModel
from safetensors.torch import load_file
import argparse
import yaml
import os
# accept args to specify v1 or v1_5
parser = argparse.ArgumentParser(description='Specify version v1 or v1_5')
parser.add_argument('--version', choices=['v1', 'v1_5'], required=True, help='Specify the version: v1 or v1_5')
parser.add_argument('--weights_dir', type=str, required=True, help='Specify the path to the safetensor file', default='weights/icon_detect_v1_5')
args = parser.parse_args()
if args.version == 'v1':
tensor_dict = load_file("weights/icon_detect/model.safetensors")
model = DetectionModel('weights/icon_detect/model.yaml')
model.load_state_dict(tensor_dict)
torch.save({'model':model}, 'weights/icon_detect/best.pt')
elif args.version == 'v1_5':
print("Converting v1_5")
tensor_dict = load_file("weights/icon_detect_v1_5/model.safetensors")
model = DetectionModel('weights/icon_detect_v1_5/model.yaml')
model.load_state_dict(tensor_dict)
save_dict = {'model':model}
tensor_dict = load_file(os.path.join(args.weights_dir, "model.safetensors"))
model = DetectionModel(os.path.join(args.weights_dir, "model.yaml"))
# from ultralytics import YOLO
# som_model = YOLO("yolo11m.pt")
# model = som_model.model
model.load_state_dict(tensor_dict)
save_dict = {'model':model}
with open(os.path.join(args.weights_dir, "train_args.yaml"), 'r') as file:
train_args = yaml.safe_load(file)
save_dict.update(train_args)
torch.save(save_dict, os.path.join(args.weights_dir, "best.pt"))
with open("weights/icon_detect_v1_5/train_args.yaml", 'r') as file:
train_args = yaml.safe_load(file)
save_dict.update(train_args)
torch.save(save_dict, 'weights/icon_detect_v1_5/best.pt')