Files
saw_mill_knot_detection/predict_rfdetr.py

42 lines
1.5 KiB
Python

from __future__ import annotations
import argparse
from pathlib import Path
from PIL import Image
def main() -> int:
    """Run a fine-tuned RF-DETR checkpoint on one image and print its detections.

    Command-line arguments:
        --weights    Path to the checkpoint file (e.g. checkpoint_best_total.pth).
        --image      Path to the input image.
        --threshold  Confidence threshold in [0, 1] passed to ``model.predict``
                     (default 0.5).

    Output (stdout): a ``num_detections=<N>`` line, then one dict per detection
    with keys ``i``, ``class_id``, ``confidence`` and ``xyxy``. ``class_id`` /
    ``confidence`` are ``None`` when the detections object carries no such data.

    Returns:
        0 on success.

    Raises:
        SystemExit: exit status 2 (via argparse) if ``--threshold`` is outside
            [0, 1]; otherwise with a message if the weights or image path is not
            an existing file, or if the image cannot be opened/decoded.
    """
    parser = argparse.ArgumentParser(description="Run inference with a fine-tuned RF-DETR checkpoint.")
    parser.add_argument("--weights", type=Path, required=True, help="Path to checkpoint_best_total.pth (or similar)")
    parser.add_argument("--image", type=Path, required=True, help="Path to an image")
    parser.add_argument("--threshold", type=float, default=0.5, help="Confidence threshold in [0, 1]")
    args = parser.parse_args()

    # The negated chained comparison also rejects NaN, which compares False to everything.
    if not 0.0 <= args.threshold <= 1.0:
        parser.error(f"--threshold must be in [0, 1], got {args.threshold}")

    # is_file() rather than exists(): a directory would pass exists() and only
    # fail later inside the loaders with a much less helpful error.
    if not args.weights.is_file():
        raise SystemExit(f"Weights not found: {args.weights}")
    if not args.image.is_file():
        raise SystemExit(f"Image not found: {args.image}")

    # Decode the image before the (expensive) model load so a bad file fails fast.
    # Pillow opens files lazily and holds the handle; the `with` closes it once
    # convert() has produced an in-memory RGB copy.
    try:
        with Image.open(args.image) as img:
            image = img.convert("RGB")
    except OSError as err:  # also covers PIL.UnidentifiedImageError (an OSError subclass)
        raise SystemExit(f"Could not read image {args.image}: {err}") from err

    # Imported lazily so `--help` and argument errors don't pay for importing the model stack.
    from rfdetr import RFDETRBase

    model = RFDETRBase(pretrain_weights=str(args.weights))
    detections = model.predict(image, threshold=args.threshold)

    # Print detections in a simple, script-friendly way.
    # `detections` is a supervision.Detections object; its per-detection
    # confidence / class_id arrays may be None.
    confidences = detections.confidence
    class_ids = detections.class_id
    print(f"num_detections={len(detections)}")
    for i, box in enumerate(detections.xyxy):
        conf = None if confidences is None else float(confidences[i])
        cls = None if class_ids is None else int(class_ids[i])
        print({"i": i, "class_id": cls, "confidence": conf, "xyxy": box.tolist()})
    return 0
if __name__ == "__main__":
    # Script entry point: main()'s integer result becomes the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)