"""
Run inference with trained RT-DETR model.

Usage:
    python predict_rtdetr.py --weights runs/rtdetr_training/training/weights/best.pt --image test.jpg
"""

import argparse
from pathlib import Path

import cv2  # NOTE(review): cv2 is not referenced anywhere in this script; confirm it is still needed.


_RULE = "=" * 60


def predict(weights_path: Path, image_path: Path, threshold: float = 0.5, save: bool = True):
    """
    Run RT-DETR inference on an image and print the detections.

    Args:
        weights_path: Path (or path string) to trained .pt weights.
        image_path: Path (or path string) to input image.
        threshold: Detection confidence threshold, must lie in [0, 1].
        save: Whether to save the annotated visualization.

    Returns:
        The results list returned by ``RTDETR.predict`` (one entry per source image).

    Raises:
        ValueError: If the weights or image file does not exist, or if
            ``threshold`` is outside [0, 1].
    """
    weights_path = Path(weights_path)
    image_path = Path(image_path)

    # Validate cheap preconditions before paying for the heavy ultralytics import,
    # so bad CLI arguments fail immediately.
    if not weights_path.exists():
        raise ValueError(f"Weights not found: {weights_path}")
    if not image_path.exists():
        raise ValueError(f"Image not found: {image_path}")
    if not 0.0 <= threshold <= 1.0:
        raise ValueError(f"Threshold must be in [0, 1], got {threshold}")

    from ultralytics import RTDETR

    print(f"\n{_RULE}")
    print("RT-DETR Inference")
    print(f"Weights: {weights_path}")
    print(f"Image: {image_path}")
    print(f"Threshold: {threshold}")
    print(f"{_RULE}\n")

    # Load model
    model = RTDETR(str(weights_path))

    # Run inference
    results = model.predict(
        source=str(image_path),
        conf=threshold,
        save=save,
        show_labels=True,
        show_conf=True,
    )

    # Print detections and remember where each result was written.
    save_dirs = []
    for result in results:
        # Guard against a result that carries no box container at all.
        boxes = result.boxes if result.boxes is not None else []
        print(f"\nDetected {len(boxes)} knots:")
        for i, box in enumerate(boxes, start=1):
            conf = box.conf[0].item()
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            print(f"  {i}. Confidence: {conf:.3f}, BBox: [{x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}]")

        # The predictor chooses the output directory per run, so read it back
        # from the result instead of guessing a fixed path.
        save_dir = getattr(result, "save_dir", None)
        if save_dir and save_dir not in save_dirs:
            save_dirs.append(save_dir)

    if save:
        # Fall back to the previous hard-coded location if the installed
        # ultralytics version does not expose `save_dir` on results.
        for output_dir in save_dirs or [Path("runs/detect")]:
            print(f"\n✓ Visualization saved to: {output_dir}")

    print(f"\n{_RULE}\n")

    return results


def main():
    """Parse command-line arguments and run a single RT-DETR prediction."""
    parser = argparse.ArgumentParser(description="RT-DETR inference")
    parser.add_argument(
        "--weights", type=Path, required=True,
        help="Path to trained .pt weights"
    )
    parser.add_argument(
        "--image", type=Path, required=True,
        help="Path to input image"
    )
    parser.add_argument(
        "--threshold", type=float, default=0.5,
        help="Detection confidence threshold"
    )
    parser.add_argument(
        "--no-save", action="store_true",
        help="Don't save visualization"
    )

    args = parser.parse_args()

    predict(args.weights, args.image, args.threshold, save=not args.no_save)


if __name__ == "__main__":
    main()