42 lines
1.5 KiB
Python
42 lines
1.5 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
from PIL import Image
|
||
|
|
|
||
|
|
|
||
|
|
def main() -> int:
    """Run single-image inference with a fine-tuned RF-DETR checkpoint.

    Parses ``--weights``, ``--image`` and ``--threshold`` from the command
    line, loads the checkpoint, runs prediction on the image, and prints the
    number of detections followed by one dict per detection.

    Returns:
        ``0`` once all detections have been printed.

    Raises:
        SystemExit: With an explanatory message when the weights or image
            path is not an existing regular file. ``argparse`` also exits
            on missing or malformed arguments.
    """
    parser = argparse.ArgumentParser(
        description="Run inference with a fine-tuned RF-DETR checkpoint."
    )
    parser.add_argument(
        "--weights",
        type=Path,
        required=True,
        help="Path to checkpoint_best_total.pth (or similar)",
    )
    parser.add_argument("--image", type=Path, required=True, help="Path to an image")
    parser.add_argument("--threshold", type=float, default=0.5)
    args = parser.parse_args()

    # `is_file()` (rather than `exists()`) also rejects directories, which
    # would otherwise pass validation and fail later with an unrelated error.
    if not args.weights.is_file():
        raise SystemExit(f"Weights not found: {args.weights}")
    if not args.image.is_file():
        raise SystemExit(f"Image not found: {args.image}")

    # Imported here rather than at module level -- presumably so argument
    # errors are reported before the (likely heavy) model package is loaded.
    from rfdetr import RFDETRBase

    model = RFDETRBase(pretrain_weights=str(args.weights))

    # Convert inside the context manager so the underlying file handle is
    # released deterministically; `convert` returns an independent image.
    with Image.open(args.image) as source_image:
        image = source_image.convert("RGB")
    detections = model.predict(image, threshold=args.threshold)

    # Print detections in a simple, script-friendly way.
    # `detections` is a supervision.Detections object.
    print(f"num_detections={len(detections)}")

    # Both attributes may be None for the whole result; read them once.
    confidences = detections.confidence
    class_ids = detections.class_id
    for i, box in enumerate(detections.xyxy):
        conf = float(confidences[i]) if confidences is not None else None
        cls = int(class_ids[i]) if class_ids is not None else None
        print({"i": i, "class_id": cls, "confidence": conf, "xyxy": box.tolist()})

    return 0
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: run the CLI and propagate its return value as the
# process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
|