Initial commit: Wood knot detection model and GUI
This commit is contained in:
99
predict_rtdetr.py
Normal file
99
predict_rtdetr.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""
|
||||
Run inference with trained RT-DETR model.
|
||||
|
||||
Usage:
|
||||
python predict_rtdetr.py --weights runs/rtdetr_training/training/weights/best.pt --image test.jpg
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
|
||||
|
||||
def predict(weights_path: Path, image_path: Path, threshold: float = 0.5, save: bool = True):
|
||||
"""
|
||||
Run RT-DETR inference on an image.
|
||||
|
||||
Args:
|
||||
weights_path: Path to trained .pt weights
|
||||
image_path: Path to input image
|
||||
threshold: Detection confidence threshold
|
||||
save: Whether to save visualization
|
||||
"""
|
||||
from ultralytics import RTDETR
|
||||
|
||||
if not weights_path.exists():
|
||||
raise ValueError(f"Weights not found: {weights_path}")
|
||||
if not image_path.exists():
|
||||
raise ValueError(f"Image not found: {image_path}")
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"RT-DETR Inference")
|
||||
print(f"Weights: {weights_path}")
|
||||
print(f"Image: {image_path}")
|
||||
print(f"Threshold: {threshold}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Load model
|
||||
model = RTDETR(weights_path)
|
||||
|
||||
# Run inference
|
||||
results = model.predict(
|
||||
source=str(image_path),
|
||||
conf=threshold,
|
||||
save=save,
|
||||
show_labels=True,
|
||||
show_conf=True,
|
||||
)
|
||||
|
||||
# Print detections
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
print(f"\nDetected {len(boxes)} knots:")
|
||||
for i, box in enumerate(boxes):
|
||||
conf = box.conf[0].item()
|
||||
xyxy = box.xyxy[0].cpu().numpy()
|
||||
print(f" {i+1}. Confidence: {conf:.3f}, BBox: [{xyxy[0]:.0f}, {xyxy[1]:.0f}, {xyxy[2]:.0f}, {xyxy[3]:.0f}]")
|
||||
|
||||
if save:
|
||||
output_dir = Path("runs/detect")
|
||||
print(f"\n✓ Visualization saved to: {output_dir}")
|
||||
|
||||
print(f"\n{'='*60}\n")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="RT-DETR inference")
|
||||
parser.add_argument(
|
||||
"--weights",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to trained .pt weights"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to input image"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Detection confidence threshold"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-save",
|
||||
action="store_true",
|
||||
help="Don't save visualization"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
predict(args.weights, args.image, args.threshold, save=not args.no_save)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user