Initial commit: Wood knot detection model and GUI
This commit is contained in:
163
train_yolov6.py
Normal file
163
train_yolov6.py
Normal file
@ -0,0 +1,163 @@
|
||||
"""
|
||||
Train YOLOv6 for knot detection (MIT license - free for commercial use).
|
||||
|
||||
YOLOv6 is from Meituan, optimized for deployment on edge devices.
|
||||
|
||||
Usage:
|
||||
python train_yolov6.py --dataset-dir dataset_prepared --model yolov6n --epochs 100
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def train_yolov6(
|
||||
dataset_dir: Path,
|
||||
output_dir: Path,
|
||||
model_name: str = "yolov6n",
|
||||
epochs: int = 100,
|
||||
batch_size: int = 8,
|
||||
img_size: int = 640,
|
||||
learning_rate: float = 1e-2,
|
||||
):
|
||||
"""
|
||||
Train YOLOv6 model.
|
||||
|
||||
Args:
|
||||
dataset_dir: Path to dataset with train/valid/test splits
|
||||
output_dir: Where to save checkpoints
|
||||
model_name: One of ['yolov6n', 'yolov6s', 'yolov6m', 'yolov6l']
|
||||
epochs: Number of training epochs
|
||||
batch_size: Batch size
|
||||
img_size: Input image size
|
||||
learning_rate: Learning rate
|
||||
"""
|
||||
# Install YOLOv6 if not already installed
|
||||
try:
|
||||
import yolov6
|
||||
except ImportError:
|
||||
print("Installing YOLOv6...")
|
||||
subprocess.check_call([
|
||||
sys.executable, "-m", "pip", "install",
|
||||
"git+https://github.com/meituan/YOLOv6.git"
|
||||
])
|
||||
|
||||
# Validate dataset structure
|
||||
train_dir = dataset_dir / "train"
|
||||
valid_dir = dataset_dir / "valid"
|
||||
|
||||
if not train_dir.exists() or not valid_dir.exists():
|
||||
raise ValueError(f"Dataset must have train/ and valid/ directories")
|
||||
|
||||
train_ann = train_dir / "_annotations.coco.json"
|
||||
valid_ann = valid_dir / "_annotations.coco.json"
|
||||
|
||||
if not train_ann.exists():
|
||||
raise ValueError(f"Missing {train_ann}")
|
||||
if not valid_ann.exists():
|
||||
raise ValueError(f"Missing {valid_ann}")
|
||||
|
||||
# Create output directory
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create data config file for YOLOv6
|
||||
data_yaml = output_dir / "data.yaml"
|
||||
with data_yaml.open("w") as f:
|
||||
f.write(f"""train: {train_dir.absolute()}
|
||||
val: {valid_dir.absolute()}
|
||||
|
||||
nc: 1
|
||||
names: ['knot']
|
||||
""")
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Training YOLOv6-{model_name}")
|
||||
print(f"Dataset: {dataset_dir}")
|
||||
print(f"Output: {output_dir}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Map model names
|
||||
model_map = {
|
||||
"yolov6n": "yolov6n",
|
||||
"yolov6s": "yolov6s",
|
||||
"yolov6m": "yolov6m",
|
||||
"yolov6l": "yolov6l",
|
||||
}
|
||||
|
||||
if model_name not in model_map:
|
||||
raise ValueError(f"Model must be one of {list(model_map.keys())}")
|
||||
|
||||
# Build training command
|
||||
yolov6_dir = Path(sys.executable).parent.parent / "YOLOv6"
|
||||
train_script = yolov6_dir / "tools/train.py"
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(train_script),
|
||||
"--batch", str(batch_size),
|
||||
"--conf", str(yolov6_dir / f"configs/{model_name}.py"),
|
||||
"--data", str(data_yaml),
|
||||
"--epochs", str(epochs),
|
||||
"--device", "0",
|
||||
"--name", "yolov6_training",
|
||||
"--output-dir", str(output_dir),
|
||||
]
|
||||
|
||||
print(f"Running: {' '.join(cmd)}\n")
|
||||
|
||||
result = subprocess.run(cmd)
|
||||
|
||||
if result.returncode == 0:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"✓ Training complete!")
|
||||
print(f"Weights saved in: {output_dir}/yolov6_training")
|
||||
print(f"{'='*60}\n")
|
||||
else:
|
||||
print(f"\n❌ Training failed with exit code {result.returncode}")
|
||||
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Train YOLOv6 for knot detection")
|
||||
parser.add_argument(
|
||||
"--dataset-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to dataset directory with train/valid/test splits"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("runs/yolov6_training"),
|
||||
help="Output directory for checkpoints and logs"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
choices=["yolov6n", "yolov6s", "yolov6m", "yolov6l"],
|
||||
default="yolov6n",
|
||||
help="YOLOv6 model variant (n=smallest/fastest, l=largest/most accurate)"
|
||||
)
|
||||
parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
|
||||
parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
|
||||
parser.add_argument("--img-size", type=int, default=640, help="Input image size")
|
||||
parser.add_argument("--lr", type=float, default=1e-2, help="Learning rate")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
train_yolov6(
|
||||
dataset_dir=args.dataset_dir,
|
||||
output_dir=args.output_dir,
|
||||
model_name=args.model,
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batch_size,
|
||||
img_size=args.img_size,
|
||||
learning_rate=args.lr,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user