aai/main.py

314 lines
10 KiB
Python
Raw Permalink Normal View History

2026-01-26 15:19:15 +00:00
"""
🎯 MAIN - Główny skrypt MiniGPT-60M z checkpointami
"""
import os
import sys
import argparse
import signal
from pathlib import Path
# Dodaj ścieżkę do importów
sys.path.append('.')
from config import logger, cfg, sys_config
import prepare_data
import ai
# ==================== OBSŁUGA SYGNAŁÓW ====================
def signal_handler(sig, frame):
"""Obsługa przerwania (Ctrl+C)"""
logger.info("\n\n⚠️ Trening przerwany przez użytkownika!")
logger.info("💾 Zapisuję stan do wznowienia...")
# Tutaj można dodać zapis stanu przed wyjściem
# W aktualnej implementacji stan jest zapisywany automatycznie
logger.info("✅ Możesz wznowić trening używając: python main.py --cont")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
# ==================== FUNKCJE GŁÓWNE ====================
def prepare_all_data():
"""Przygotowuje wszystkie dane"""
logger.info("=" * 60)
logger.info("📊 PRZYGOTOWYWANIE WSZYSTKICH DANYCH")
logger.info("=" * 60)
preparer = prepare_data.DataPreparer()
preparer.prepare_all_data()
def train_model(resume: bool = False):
"""Trenuje model"""
logger.info("=" * 60)
if resume:
logger.info("🔄 WZNIOWANIE TRENINGU MODELU")
else:
logger.info("🚀 TRENING MODELU MINIGPT-60M")
logger.info("=" * 60)
# Sprawdź czy dane są przygotowane
data_file = Path(cfg.prepared_dir) / "all_data.txt"
if not data_file.exists():
logger.warning("⚠️ Brak przygotowanych danych. Przygotowuję...")
prepare_all_data()
# Uruchom trening
if resume:
ai.continue_training()
else:
ai.train_model(resume=False)
def continue_training():
"""Kontynuuje trening od ostatniego checkpointu"""
train_model(resume=True)
def train_more_epochs(additional_epochs: int = 3):
"""Dodaje więcej epok treningu"""
logger.info("=" * 60)
logger.info(f"📈 DODATKOWY TRENING: {additional_epochs} epok")
logger.info("=" * 60)
# Sprawdź czy dane są przygotowane
data_file = Path(cfg.prepared_dir) / "all_data.txt"
if not data_file.exists():
logger.error("❌ Brak przygotowanych danych!")
logger.info("💡 Uruchom: python main.py --prepare")
return
# Uruchom dodatkowy trening
ai.train_more_epochs(additional_epochs)
def generate_text(prompt: str):
"""Generuje tekst"""
logger.info(f"🎨 GENEROWANIE TEKSTU: '{prompt}'")
ai.generate_text(prompt)
def start_chat():
"""Uruchamia czat"""
ai.start_chat()
def evaluate_model():
"""Ocenia model"""
logger.info("🎯 OCENA MODELU")
# Sprawdź czy model istnieje
model_path = cfg.get_latest_model()
if not model_path:
logger.error("❌ Brak wytrenowanego modelu!")
logger.info("💡 Uruchom: python main.py --train")
return
logger.info(f"✅ Model wczytany: {model_path.name}")
# Tutaj można dodać ewaluację
logger.info("📊 Dodaj ewaluację w ai.py")
def show_checkpoints():
"""Pokazuje dostępne checkpointy"""
logger.info("📁 DOSTĘPNE CHECKPOINTY:")
checkpoints = list(Path(cfg.checkpoints_dir).glob("checkpoint_*.pt"))
models = list(Path(cfg.model_dir).glob("model_*.pt"))
if checkpoints:
logger.info("\n🔽 CHECKPOINTY (do wznowienia):")
for cp in sorted(checkpoints, key=lambda x: x.stat().st_mtime, reverse=True)[:5]:
size_mb = cp.stat().st_size / (1024 * 1024)
logger.info(f"{cp.name} ({size_mb:.1f} MB)")
else:
logger.info(" ❌ Brak checkpointów")
if models:
logger.info("\n🤖 MODELE:")
for model in sorted(models, key=lambda x: x.stat().st_mtime, reverse=True)[:5]:
size_mb = model.stat().st_size / (1024 * 1024)
logger.info(f"{model.name} ({size_mb:.1f} MB)")
# Sprawdź stan resume
resume_state = cfg.load_resume_state()
if resume_state:
logger.info(f"\n🔄 OSTATNI STAN TRENINGU:")
logger.info(f" • Epoka: {resume_state.get('epoch', 'N/A')}")
logger.info(f" • Krok: {resume_state.get('step', 'N/A')}")
logger.info(f" • Loss: {resume_state.get('train_loss', 'N/A')}")
def clean_checkpoints(keep_last: int = 3):
"""Czyści stare checkpointy"""
logger.info(f"🧹 CZYSZCZENIE STARYCH CHECKPOINTÓW (zachowuję {keep_last} najnowszych)")
checkpoints = list(Path(cfg.checkpoints_dir).glob("checkpoint_*.pt"))
if len(checkpoints) <= keep_last:
logger.info("✅ Nie ma czego czyścić")
return
# Sortuj od najstarszego do najnowszego
checkpoints.sort(key=lambda x: x.stat().st_mtime)
# Usuń wszystkie poza keep_last najnowszych
to_delete = checkpoints[:-keep_last]
for cp in to_delete:
try:
cp.unlink()
logger.info(f" 🗑️ Usunięto: {cp.name}")
except Exception as e:
logger.error(f" ❌ Błąd usuwania {cp.name}: {e}")
logger.info(f"✅ Pozostawiono {keep_last} najnowszych checkpointów")
def show_config():
"""Pokazuje konfigurację"""
cfg.print_config()
def run_tests():
"""Uruchamia testy"""
logger.info("🧪 TESTY JEDNOSTKOWE")
# Test tokenizera
from ai import tokenizer
text = "Test tokenizacji"
ids = tokenizer.encode(text)
decoded = tokenizer.decode(ids)
logger.info(f"✅ Tokenizer: '{text}' -> '{decoded}'")
# Test modelu
import torch
model = ai.MiniGPT60M()
x = torch.randint(0, cfg.vocab_size, (2, 32))
logits = model(x)
logger.info(f"✅ Model: logits shape {logits.shape}")
logger.info("✅ Wszystkie testy przeszły pomyślnie!")
# ==================== GŁÓWNA FUNKCJA ====================
def main():
"""Główna funkcja programu"""
parser = argparse.ArgumentParser(
description="🎯 MiniGPT-60M: Zaawansowany model językowy ~60M parametrów",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Przykłady użycia:
python main.py --prepare # Przygotuj dane ze wszystkich folderów data_*
python main.py --train # Trenuj model od początku
python main.py --cont # Wznów trening od ostatniego checkpointu
python main.py --more [N] # Dodaj N epok treningu (domyślnie 3)
python main.py --generate "AI" # Generuj tekst
python main.py --chat # Rozmawiaj z modelem
python main.py --checkpoints # Pokaż dostępne checkpointy
python main.py --clean-cp # Wyczyść stare checkpointy
python main.py --test # Uruchom testy
python main.py --config # Pokaż konfigurację
python main.py --evaluate # Oceń model
python main.py --all # Przygotuj dane i trenuj od początku
Kontynuacja treningu:
Rozpocznij trening: python main.py --train
Przerwij (Ctrl+C): Zapisz stan automatycznie
Wznów: python main.py --cont
Dodaj epoki: python main.py --more 5
"""
)
parser.add_argument("--prepare", action="store_true", help="Przygotuj dane")
parser.add_argument("--train", action="store_true", help="Trening od początku")
parser.add_argument("--cont", action="store_true", help="Kontynuuj trening od checkpointu")
parser.add_argument("--more", type=int, nargs='?', const=3, help="Dodaj epoki treningu (domyślnie 3)")
parser.add_argument("--generate", type=str, help="Generuj tekst z promptu")
parser.add_argument("--chat", action="store_true", help="Tryb rozmowy")
parser.add_argument("--checkpoints", action="store_true", help="Pokaż dostępne checkpointy")
parser.add_argument("--clean-cp", action="store_true", help="Wyczyść stare checkpointy")
parser.add_argument("--test", action="store_true", help="Testy jednostkowe")
parser.add_argument("--config", action="store_true", help="Pokaż konfigurację")
parser.add_argument("--evaluate", action="store_true", help="Oceń model")
parser.add_argument("--all", action="store_true", help="Przygotuj dane i trenuj")
parser.add_argument("--epochs", type=int, default=cfg.epochs, help="Liczba epok")
parser.add_argument("--model", type=str, help="Ścieżka do konkretnego modelu")
args = parser.parse_args()
# Pokaż nagłówek
logger.info("=" * 60)
logger.info("🎯 MINIGPT-60M - Z CHECKPOINTAMI")
logger.info("=" * 60)
# Update liczby epok jeśli podano
if args.epochs != cfg.epochs:
cfg.epochs = args.epochs
logger.info(f"⚙️ Ustawiono {cfg.epochs} epok")
# Uruchom odpowiednią funkcję
if args.prepare:
prepare_all_data()
elif args.train:
train_model(resume=False)
elif args.cont:
continue_training()
elif args.more is not None:
train_more_epochs(additional_epochs=args.more)
elif args.generate:
generate_text(args.generate)
elif args.chat:
start_chat()
elif args.checkpoints:
show_checkpoints()
elif args.clean_cp:
clean_checkpoints()
elif args.test:
run_tests()
elif args.config:
show_config()
elif args.evaluate:
evaluate_model()
elif args.all:
prepare_all_data()
train_model(resume=False)
else:
# Jeśli żadna flaga, pokaż help
logger.info("\n❓ Nie podano flagi. Dostępne opcje:\n")
parser.print_help()
# Pokaż dodatkowe informacje
logger.info("\n💡 PRZYKŁADY UŻYCIA:")
logger.info(" python main.py --train # Trenuj od początku")
logger.info(" [Ctrl+C] # Przerwij trening")
logger.info(" python main.py --cont # Wznów trening")
logger.info(" python main.py --more 5 # Dodaj 5 epok")
logger.info(" python main.py --generate 'AI' # Generuj tekst")
# Sprawdź czy są checkpointy
if cfg.get_latest_checkpoint():
logger.info("\n🔄 Dostępne checkpointy do wznowienia!")
logger.info("=" * 60)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
logger.info("\n\n👋 Program przerwany przez użytkownika")
sys.exit(0)
except Exception as e:
logger.error(f"\n❌ Krytyczny błąd: {e}")
import traceback
traceback.print_exc()
sys.exit(1)