""" 🎯 MAIN - Główny skrypt MiniGPT-60M z checkpointami """ import os import sys import argparse import signal from pathlib import Path # Dodaj ścieżkę do importów sys.path.append('.') from config import logger, cfg, sys_config import prepare_data import ai # ==================== OBSŁUGA SYGNAŁÓW ==================== def signal_handler(sig, frame): """Obsługa przerwania (Ctrl+C)""" logger.info("\n\n⚠️ Trening przerwany przez użytkownika!") logger.info("💾 Zapisuję stan do wznowienia...") # Tutaj można dodać zapis stanu przed wyjściem # W aktualnej implementacji stan jest zapisywany automatycznie logger.info("✅ Możesz wznowić trening używając: python main.py --cont") sys.exit(0) signal.signal(signal.SIGINT, signal_handler) # ==================== FUNKCJE GŁÓWNE ==================== def prepare_all_data(): """Przygotowuje wszystkie dane""" logger.info("=" * 60) logger.info("📊 PRZYGOTOWYWANIE WSZYSTKICH DANYCH") logger.info("=" * 60) preparer = prepare_data.DataPreparer() preparer.prepare_all_data() def train_model(resume: bool = False): """Trenuje model""" logger.info("=" * 60) if resume: logger.info("🔄 WZNIOWANIE TRENINGU MODELU") else: logger.info("🚀 TRENING MODELU MINIGPT-60M") logger.info("=" * 60) # Sprawdź czy dane są przygotowane data_file = Path(cfg.prepared_dir) / "all_data.txt" if not data_file.exists(): logger.warning("⚠️ Brak przygotowanych danych. Przygotowuję...") prepare_all_data() # Uruchom trening if resume: ai.continue_training() else: ai.train_model(resume=False) def continue_training(): """Kontynuuje trening od ostatniego checkpointu""" train_model(resume=True) def train_more_epochs(additional_epochs: int = 3): """Dodaje więcej epok treningu""" logger.info("=" * 60) logger.info(f"📈 DODATKOWY TRENING: {additional_epochs} epok") logger.info("=" * 60) # Sprawdź czy dane są przygotowane data_file = Path(cfg.prepared_dir) / "all_data.txt" if not data_file.exists(): logger.error("❌ Brak przygotowanych danych!") logger.info("💡 Uruchom: python main.py --prepare") return # Uruchom dodatkowy trening ai.train_more_epochs(additional_epochs) def generate_text(prompt: str): """Generuje tekst""" logger.info(f"🎨 GENEROWANIE TEKSTU: '{prompt}'") ai.generate_text(prompt) def start_chat(): """Uruchamia czat""" ai.start_chat() def evaluate_model(): """Ocenia model""" logger.info("🎯 OCENA MODELU") # Sprawdź czy model istnieje model_path = cfg.get_latest_model() if not model_path: logger.error("❌ Brak wytrenowanego modelu!") logger.info("💡 Uruchom: python main.py --train") return logger.info(f"✅ Model wczytany: {model_path.name}") # Tutaj można dodać ewaluację logger.info("📊 Dodaj ewaluację w ai.py") def show_checkpoints(): """Pokazuje dostępne checkpointy""" logger.info("📁 DOSTĘPNE CHECKPOINTY:") checkpoints = list(Path(cfg.checkpoints_dir).glob("checkpoint_*.pt")) models = list(Path(cfg.model_dir).glob("model_*.pt")) if checkpoints: logger.info("\n🔽 CHECKPOINTY (do wznowienia):") for cp in sorted(checkpoints, key=lambda x: x.stat().st_mtime, reverse=True)[:5]: size_mb = cp.stat().st_size / (1024 * 1024) logger.info(f" • {cp.name} ({size_mb:.1f} MB)") else: logger.info(" ❌ Brak checkpointów") if models: logger.info("\n🤖 MODELE:") for model in sorted(models, key=lambda x: x.stat().st_mtime, reverse=True)[:5]: size_mb = model.stat().st_size / (1024 * 1024) logger.info(f" • {model.name} ({size_mb:.1f} MB)") # Sprawdź stan resume resume_state = cfg.load_resume_state() if resume_state: logger.info(f"\n🔄 OSTATNI STAN TRENINGU:") logger.info(f" • Epoka: {resume_state.get('epoch', 'N/A')}") logger.info(f" • Krok: {resume_state.get('step', 'N/A')}") logger.info(f" • Loss: {resume_state.get('train_loss', 'N/A')}") def clean_checkpoints(keep_last: int = 3): """Czyści stare checkpointy""" logger.info(f"🧹 CZYSZCZENIE STARYCH CHECKPOINTÓW (zachowuję {keep_last} najnowszych)") checkpoints = list(Path(cfg.checkpoints_dir).glob("checkpoint_*.pt")) if len(checkpoints) <= keep_last: logger.info("✅ Nie ma czego czyścić") return # Sortuj od najstarszego do najnowszego checkpoints.sort(key=lambda x: x.stat().st_mtime) # Usuń wszystkie poza keep_last najnowszych to_delete = checkpoints[:-keep_last] for cp in to_delete: try: cp.unlink() logger.info(f" 🗑️ Usunięto: {cp.name}") except Exception as e: logger.error(f" ❌ Błąd usuwania {cp.name}: {e}") logger.info(f"✅ Pozostawiono {keep_last} najnowszych checkpointów") def show_config(): """Pokazuje konfigurację""" cfg.print_config() def run_tests(): """Uruchamia testy""" logger.info("🧪 TESTY JEDNOSTKOWE") # Test tokenizera from ai import tokenizer text = "Test tokenizacji" ids = tokenizer.encode(text) decoded = tokenizer.decode(ids) logger.info(f"✅ Tokenizer: '{text}' -> '{decoded}'") # Test modelu import torch model = ai.MiniGPT60M() x = torch.randint(0, cfg.vocab_size, (2, 32)) logits = model(x) logger.info(f"✅ Model: logits shape {logits.shape}") logger.info("✅ Wszystkie testy przeszły pomyślnie!") # ==================== GŁÓWNA FUNKCJA ==================== def main(): """Główna funkcja programu""" parser = argparse.ArgumentParser( description="🎯 MiniGPT-60M: Zaawansowany model językowy ~60M parametrów", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Przykłady użycia: python main.py --prepare # Przygotuj dane ze wszystkich folderów data_* python main.py --train # Trenuj model od początku python main.py --cont # Wznów trening od ostatniego checkpointu python main.py --more [N] # Dodaj N epok treningu (domyślnie 3) python main.py --generate "AI" # Generuj tekst python main.py --chat # Rozmawiaj z modelem python main.py --checkpoints # Pokaż dostępne checkpointy python main.py --clean-cp # Wyczyść stare checkpointy python main.py --test # Uruchom testy python main.py --config # Pokaż konfigurację python main.py --evaluate # Oceń model python main.py --all # Przygotuj dane i trenuj od początku Kontynuacja treningu: Rozpocznij trening: python main.py --train Przerwij (Ctrl+C): Zapisz stan automatycznie Wznów: python main.py --cont Dodaj epoki: python main.py --more 5 """ ) parser.add_argument("--prepare", action="store_true", help="Przygotuj dane") parser.add_argument("--train", action="store_true", help="Trening od początku") parser.add_argument("--cont", action="store_true", help="Kontynuuj trening od checkpointu") parser.add_argument("--more", type=int, nargs='?', const=3, help="Dodaj epoki treningu (domyślnie 3)") parser.add_argument("--generate", type=str, help="Generuj tekst z promptu") parser.add_argument("--chat", action="store_true", help="Tryb rozmowy") parser.add_argument("--checkpoints", action="store_true", help="Pokaż dostępne checkpointy") parser.add_argument("--clean-cp", action="store_true", help="Wyczyść stare checkpointy") parser.add_argument("--test", action="store_true", help="Testy jednostkowe") parser.add_argument("--config", action="store_true", help="Pokaż konfigurację") parser.add_argument("--evaluate", action="store_true", help="Oceń model") parser.add_argument("--all", action="store_true", help="Przygotuj dane i trenuj") parser.add_argument("--epochs", type=int, default=cfg.epochs, help="Liczba epok") parser.add_argument("--model", type=str, help="Ścieżka do konkretnego modelu") args = parser.parse_args() # Pokaż nagłówek logger.info("=" * 60) logger.info("🎯 MINIGPT-60M - Z CHECKPOINTAMI") logger.info("=" * 60) # Update liczby epok jeśli podano if args.epochs != cfg.epochs: cfg.epochs = args.epochs logger.info(f"⚙️ Ustawiono {cfg.epochs} epok") # Uruchom odpowiednią funkcję if args.prepare: prepare_all_data() elif args.train: train_model(resume=False) elif args.cont: continue_training() elif args.more is not None: train_more_epochs(additional_epochs=args.more) elif args.generate: generate_text(args.generate) elif args.chat: start_chat() elif args.checkpoints: show_checkpoints() elif args.clean_cp: clean_checkpoints() elif args.test: run_tests() elif args.config: show_config() elif args.evaluate: evaluate_model() elif args.all: prepare_all_data() train_model(resume=False) else: # Jeśli żadna flaga, pokaż help logger.info("\n❓ Nie podano flagi. Dostępne opcje:\n") parser.print_help() # Pokaż dodatkowe informacje logger.info("\n💡 PRZYKŁADY UŻYCIA:") logger.info(" python main.py --train # Trenuj od początku") logger.info(" [Ctrl+C] # Przerwij trening") logger.info(" python main.py --cont # Wznów trening") logger.info(" python main.py --more 5 # Dodaj 5 epok") logger.info(" python main.py --generate 'AI' # Generuj tekst") # Sprawdź czy są checkpointy if cfg.get_latest_checkpoint(): logger.info("\n🔄 Dostępne checkpointy do wznowienia!") logger.info("=" * 60) if __name__ == "__main__": try: main() except KeyboardInterrupt: logger.info("\n\n👋 Program przerwany przez użytkownika") sys.exit(0) except Exception as e: logger.error(f"\n❌ Krytyczny błąd: {e}") import traceback traceback.print_exc() sys.exit(1)