Initial commit: Wood knot detection model and GUI
This commit is contained in:
81
train_rfdetr.py
Normal file
81
train_rfdetr.py
Normal file
@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _build_model(model_name: str, pretrain_weights: str | None = None):
|
||||
# Import here so `--help` works without heavy imports.
|
||||
from rfdetr import RFDETRBase, RFDETRMedium, RFDETRNano, RFDETRSmall
|
||||
|
||||
model_name = model_name.lower().strip()
|
||||
|
||||
if model_name == "nano":
|
||||
return RFDETRNano(pretrain_weights=pretrain_weights) if pretrain_weights else RFDETRNano()
|
||||
if model_name == "small":
|
||||
return RFDETRSmall(pretrain_weights=pretrain_weights) if pretrain_weights else RFDETRSmall()
|
||||
if model_name == "medium":
|
||||
return RFDETRMedium(pretrain_weights=pretrain_weights) if pretrain_weights else RFDETRMedium()
|
||||
if model_name == "base":
|
||||
return RFDETRBase(pretrain_weights=pretrain_weights) if pretrain_weights else RFDETRBase()
|
||||
|
||||
raise ValueError("--model must be one of: nano, small, medium, base")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Fine-tune RF-DETR on a COCO-format knot dataset.")
|
||||
parser.add_argument("--dataset-dir", type=Path, required=True, help="Dataset root containing train/valid/test")
|
||||
parser.add_argument("--output-dir", type=Path, required=True, help="Directory where checkpoints/logs are written")
|
||||
parser.add_argument("--model", default="medium", choices=["nano", "small", "medium", "base"], help="Checkpoint size")
|
||||
|
||||
parser.add_argument("--epochs", type=int, default=50)
|
||||
parser.add_argument("--batch-size", type=int, default=4)
|
||||
parser.add_argument("--grad-accum-steps", type=int, default=4)
|
||||
parser.add_argument("--lr", type=float, default=1e-4)
|
||||
|
||||
parser.add_argument("--resume", type=Path, default=None, help="Path to checkpoint.pth to resume training")
|
||||
parser.add_argument(
|
||||
"--pretrain-weights",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Optional: start from a specific weights file instead of the default COCO pretrain",
|
||||
)
|
||||
|
||||
parser.add_argument("--early-stopping", action="store_true", help="Enable early stopping on validation mAP")
|
||||
parser.add_argument("--tensorboard", action="store_true", help="Enable TensorBoard logging")
|
||||
parser.add_argument("--wandb", action="store_true", help="Enable Weights & Biases logging")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_dir: Path = args.dataset_dir
|
||||
if not dataset_dir.exists():
|
||||
raise SystemExit(f"Dataset dir not found: {dataset_dir}")
|
||||
|
||||
output_dir: Path = args.output_dir
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = _build_model(
|
||||
args.model,
|
||||
pretrain_weights=str(args.pretrain_weights) if args.pretrain_weights else None,
|
||||
)
|
||||
|
||||
# Train. RF-DETR handles reading COCO split annotation files.
|
||||
model.train(
|
||||
dataset_dir=str(dataset_dir),
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batch_size,
|
||||
grad_accum_steps=args.grad_accum_steps,
|
||||
lr=args.lr,
|
||||
output_dir=str(output_dir),
|
||||
resume=str(args.resume) if args.resume else None,
|
||||
early_stopping=args.early_stopping,
|
||||
tensorboard=args.tensorboard,
|
||||
wandb=args.wandb,
|
||||
)
|
||||
|
||||
print(f"Training complete. Outputs in: {output_dir}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user