"""
Run inference with a trained RT-DETR model.

Usage:
    python predict_rtdetr.py --weights runs/rtdetr_training/training/weights/best.pt --image test.jpg [--threshold 0.5] [--no-save]
"""
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
from pathlib import Path
|
||
|
|
import cv2
|
||
|
|
|
||
|
|
|
||
|
|
def predict(weights_path: Path, image_path: Path, threshold: float = 0.5, save: bool = True):
    """
    Run RT-DETR inference on an image and print the detected boxes.

    Both input paths are checked before ``ultralytics`` is imported. Bad
    arguments therefore fail quickly, without first loading the library.

    Args:
        weights_path: Path to trained .pt weights (a plain ``str`` is also accepted)
        image_path: Path to input image (a plain ``str`` is also accepted)
        threshold: Detection confidence threshold, forwarded to ``predict`` as ``conf``
        save: Whether to save visualization

    Returns:
        Whatever ``RTDETR.predict`` returns (iterated here as per-image results).

    Raises:
        ValueError: If ``weights_path`` or ``image_path`` does not exist.
    """
    # Coerce so callers may pass plain strings as well as Path objects.
    weights_path = Path(weights_path)
    image_path = Path(image_path)

    if not weights_path.exists():
        raise ValueError(f"Weights not found: {weights_path}")
    if not image_path.exists():
        raise ValueError(f"Image not found: {image_path}")

    # Deferred import: ultralytics is heavy, and the checks above must not depend on it.
    from ultralytics import RTDETR

    rule = "=" * 60
    print(f"\n{rule}")
    print("RT-DETR Inference")
    print(f"Weights: {weights_path}")
    print(f"Image: {image_path}")
    print(f"Threshold: {threshold}")
    print(f"{rule}\n")

    # Load model. Pass a str: some ultralytics versions call string methods
    # on this argument, so a Path object is not safe everywhere.
    model = RTDETR(str(weights_path))

    # Run inference
    results = model.predict(
        source=str(image_path),
        conf=threshold,
        save=save,
        show_labels=True,
        show_conf=True,
    )

    # Print detections (one block per result returned by predict)
    for result in results:
        boxes = result.boxes
        print(f"\nDetected {len(boxes)} knots:")
        for index, box in enumerate(boxes, start=1):
            conf = box.conf[0].item()
            xyxy = box.xyxy[0].cpu().numpy()
            print(f" {index}. Confidence: {conf:.3f}, BBox: [{xyxy[0]:.0f}, {xyxy[1]:.0f}, {xyxy[2]:.0f}, {xyxy[3]:.0f}]")

    if save:
        # Ultralytics writes each run to its own subdirectory (e.g. runs/detect/predict2),
        # so report the directory the predictor actually used. Fall back to the parent
        # directory only if that attribute is unavailable.
        save_dir = getattr(getattr(model, "predictor", None), "save_dir", None)
        output_dir = save_dir if save_dir is not None else Path("runs/detect")
        print(f"\n✓ Visualization saved to: {output_dir}")

    print(f"\n{rule}\n")

    return results
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Parse the command-line options and hand them to ``predict``.

    ``--weights`` and ``--image`` are required paths. ``--threshold``
    defaults to 0.5, and ``--no-save`` turns off saving the visualization.
    """
    cli = argparse.ArgumentParser(description="RT-DETR inference")
    cli.add_argument("--weights", type=Path, required=True, help="Path to trained .pt weights")
    cli.add_argument("--image", type=Path, required=True, help="Path to input image")
    cli.add_argument("--threshold", type=float, default=0.5, help="Detection confidence threshold")
    cli.add_argument("--no-save", action="store_true", help="Don't save visualization")

    opts = cli.parse_args()
    should_save = not opts.no_save
    predict(opts.weights, opts.image, opts.threshold, save=should_save)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when this module is imported.
    main()
|