Initial commit: Wood knot detection model and GUI
This commit is contained in:
154
train_rtdetr.py
Normal file
154
train_rtdetr.py
Normal file
@ -0,0 +1,154 @@
|
||||
"""
|
||||
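Train RT-DETR for knot detection.

NOTE(review): the RT-DETR reference implementation is Apache 2.0 licensed, but this
script trains through the `ultralytics` package, which is distributed under AGPL-3.0
(with a separate enterprise license) - confirm the licensing terms before commercial use.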
|
||||
|
||||
RT-DETR is a real-time transformer detector that works well on edge devices like OAK cameras.
|
||||
|
||||
Usage:
|
||||
python train_rtdetr.py --dataset-dir dataset_prepared --model rtdetr-r18 --epochs 100
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
|
||||
def train_rtdetr(
    dataset_dir: Path,
    output_dir: Path,
    model_name: str = "rtdetr-r18",
    epochs: int = 100,
    batch_size: int = 8,
    img_size: int = 640,
    learning_rate: float = 1e-4,
):
    """Train an RT-DETR model on the single-class ("knot") dataset.

    All arguments are validated before anything is imported, created on disk,
    or downloaded, so a bad argument fails fast and leaves no files behind.

    Args:
        dataset_dir: Dataset root. Must contain ``train/`` and ``valid/``
            directories, each holding an ``_annotations.coco.json`` file.
        output_dir: Directory that receives the generated ``data.yaml`` and
            the ``training/`` run folder (created if missing).
        model_name: One of 'rtdetr-r18', 'rtdetr-r34', 'rtdetr-r50',
            'rtdetr-l'. Every choice currently loads the ``rtdetr-l.pt``
            pretrained weights.
        epochs: Number of training epochs.
        batch_size: Batch size.
        img_size: Input image size.
        learning_rate: Initial learning rate (passed as ``lr0``).

    Returns:
        The object returned by ``RTDETR.train``.

    Raises:
        ValueError: If ``model_name`` is unknown, if ``train/`` or ``valid/``
            is missing, or if either split lacks ``_annotations.coco.json``.
    """
    # Map model name to pretrained weights. Only rtdetr-l.pt is wired up, so
    # the smaller variants fall back to it.
    model_map = {
        "rtdetr-r18": "rtdetr-l.pt",  # Use available large model as r18 substitute
        "rtdetr-r34": "rtdetr-l.pt",  # Use available large model as r34 substitute
        "rtdetr-r50": "rtdetr-l.pt",  # Use available large model as r50 substitute
        "rtdetr-l": "rtdetr-l.pt",
    }
    if model_name not in model_map:
        raise ValueError(f"Model must be one of {list(model_map.keys())}")
    weights = model_map[model_name]

    # Validate dataset structure
    train_dir = dataset_dir / "train"
    valid_dir = dataset_dir / "valid"
    if not train_dir.exists() or not valid_dir.exists():
        raise ValueError("Dataset must have train/ and valid/ directories")

    for split_dir in (train_dir, valid_dir):
        annotations = split_dir / "_annotations.coco.json"
        if not annotations.exists():
            raise ValueError(f"Missing {annotations}")

    # Imported only after validation so invalid input is reported even when
    # the heavy third-party dependency is slow to load or not installed.
    from ultralytics import RTDETR

    # Create output directory
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create data config file for RT-DETR
    # NOTE(review): the checks above require COCO JSON annotations, while this
    # YAML only points the trainer at the split directories - confirm the
    # dataset also carries labels in the format the Ultralytics trainer reads.
    data_yaml = output_dir / "data.yaml"
    yaml_lines = [
        f"path: {dataset_dir.absolute()}",
        "train: train",
        "val: valid",
        "",
        "nc: 1",
        "names: ['knot']",
        "",
    ]
    data_yaml.write_text("\n".join(yaml_lines), encoding="utf-8")

    use_cuda = torch.cuda.is_available()
    banner = "=" * 60

    print(f"\n{banner}")
    print(f"Training RT-DETR ({model_name})")
    if weights != f"{model_name}.pt":
        # Make the substitution visible instead of implying a smaller model.
        print(f"Note: no dedicated {model_name} weights configured; using {weights}")
    print(f"Dataset: {dataset_dir}")
    print(f"Output: {output_dir}")
    print(f"Device: {'cuda' if use_cuda else 'cpu'}")
    print(f"{banner}\n")

    # Initialize model (Ultralytics will auto-download pretrained weights)
    model = RTDETR(weights)

    # Train
    run_name = "training"
    results = model.train(
        data=str(data_yaml),
        epochs=epochs,
        batch=batch_size,
        imgsz=img_size,
        lr0=learning_rate,
        project=str(output_dir),
        name=run_name,
        exist_ok=True,
        patience=20,  # Early stopping
        save=True,
        save_period=10,  # Save every 10 epochs
        plots=True,
        device=0 if use_cuda else "cpu",
    )

    weights_dir = output_dir / run_name / "weights"
    print(f"\n{banner}")
    print("✓ Training complete!")
    print(f"Best weights: {weights_dir / 'best.pt'}")
    print(f"Last weights: {weights_dir / 'last.pt'}")
    print(f"{banner}\n")

    return results
|
||||
|
||||
|
||||
def main(argv=None):
    """Parse command-line options and start RT-DETR training.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so calling
            ``main()`` behaves as before; passing a list lets other code
            drive the CLI programmatically.
    """
    parser = argparse.ArgumentParser(description="Train RT-DETR for knot detection")
    parser.add_argument(
        "--dataset-dir",
        type=Path,
        required=True,
        help="Path to dataset directory with train/valid/test splits",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("runs/rtdetr_training"),
        help="Output directory for checkpoints and logs",
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=["rtdetr-r18", "rtdetr-r34", "rtdetr-r50", "rtdetr-l"],
        default="rtdetr-r18",
        # The previous text promised r18 was the smallest/fastest, but
        # train_rtdetr maps every choice to the same rtdetr-l.pt weights.
        help="RT-DETR model variant (every variant currently loads the rtdetr-l.pt pretrained weights)",
    )
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
    parser.add_argument("--img-size", type=int, default=640, help="Input image size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")

    args = parser.parse_args(argv)

    train_rtdetr(
        dataset_dir=args.dataset_dir,
        output_dir=args.output_dir,
        model_name=args.model,
        epochs=args.epochs,
        batch_size=args.batch_size,
        img_size=args.img_size,
        learning_rate=args.lr,
    )


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user