""" 🎯 AI.PY - MiniGPT-60M z Clay Checkpoint System """ import os import sys import time import math import json import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from datetime import datetime, timedelta from pathlib import Path from typing import List, Dict, Any, Optional, Tuple import random import numpy as np from config import logger, cfg, sys_config # ==================== TOKENIZER ==================== class Tokenizer: def __init__(self): self.vocab = cfg.vocab self.vocab_size = cfg.vocab_size self.char_to_idx = {ch: i for i, ch in enumerate(self.vocab)} self.idx_to_char = {i: ch for i, ch in enumerate(self.vocab)} def encode(self, text: str) -> List[int]: """Enkoduje tekst na listę indeksów""" return [self.char_to_idx.get(ch, 0) for ch in text if ch in self.char_to_idx] def decode(self, indices: List[int]) -> str: """Dekoduje listę indeksów na tekst""" return ''.join([self.idx_to_char.get(idx, '') for idx in indices]) def encode_batch(self, texts: List[str]) -> torch.Tensor: """Enkoduje batch tekstów""" encoded = [self.encode(text) for text in texts] max_len = max(len(e) for e in encoded) padded = [e + [0] * (max_len - len(e)) for e in encoded] return torch.tensor(padded, dtype=torch.long) def decode_batch(self, tensors: torch.Tensor) -> List[str]: """Dekoduje batch tensorów""" texts = [] for tensor in tensors: indices = tensor.tolist() # Usuń padding (0) indices = [idx for idx in indices if idx != 0] texts.append(self.decode(indices)) return texts # Globalny tokenizer tokenizer = Tokenizer() # ==================== DATASET ==================== class TextDataset(Dataset): def __init__(self, filepath: str, max_len: int = 512): self.filepath = filepath self.max_len = max_len # Wczytaj dane with open(filepath, 'r', encoding='utf-8') as f: self.lines = [line.strip() for line in f if line.strip()] logger.info(f"📊 Dataset: {len(self.lines):,} linii") def __len__(self): return len(self.lines) def __getitem__(self, idx): text = self.lines[idx] # Przycinaj lub paduj do max_len if len(text) > self.max_len: start = random.randint(0, len(text) - self.max_len) text = text[start:start + self.max_len] # Enkoduj encoded = tokenizer.encode(text) # Paduj jeśli za krótkie if len(encoded) < self.max_len: encoded = encoded + [0] * (self.max_len - len(encoded)) x = torch.tensor(encoded[:-1], dtype=torch.long) y = torch.tensor(encoded[1:], dtype=torch.long) return x, y # ==================== CLAY CHECKPOINT SYSTEM ==================== class ClayCheckpoint: """Zaawansowany system checkpointów z Clay UI""" def __init__(self, model, optimizer, scheduler=None): self.model = model self.optimizer = optimizer self.scheduler = scheduler self.checkpoint_dir = Path(cfg.checkpoints_dir) self.checkpoint_dir.mkdir(exist_ok=True) # Training stats self.stats = { 'start_time': time.time(), 'epoch_times': [], 'step_times': [], 'loss_history': [], 'learning_rates': [], 'best_loss': float('inf') } # Progress tracking self.progress = { 'current_step': 0, 'total_steps': 0, 'current_epoch': 0, 'total_epochs': cfg.epochs, 'estimated_completion': None } # Progress bar style self.progress_style = { 'filled': '█', 'empty': '░', 'arrow': '▶', 'spinner': ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'] } def save(self, step, epoch, loss, is_best=False, force=False): """Zapisuje checkpoint""" if not force and step % cfg.checkpoint_freq != 0: return timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"clay_checkpoint_ep{epoch}_step{step}_{timestamp}.pt" filepath 
        filepath = self.checkpoint_dir / filename

        # Build the checkpoint state
        checkpoint = {
            'step': step,
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
            'stats': self.stats,
            'progress': self.progress,
            'timestamp': timestamp,
            'config': {
                'vocab_size': cfg.vocab_size,
                'embed_dim': cfg.embed_dim,
                'n_layers': cfg.n_layers,
                'n_heads': cfg.n_heads
            }
        }
        if self.scheduler:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

        # Save
        torch.save(checkpoint, filepath)

        # Also save a JSON summary for readability
        json_path = filepath.with_suffix('.json')
        json_data = {
            'checkpoint_info': {
                'filename': filename,
                'step': step,
                'epoch': epoch,
                'loss': float(loss),
                'timestamp': timestamp,
                'file_size': os.path.getsize(filepath)
            },
            'training_stats': {
                'total_time': time.time() - self.stats['start_time'],
                'avg_loss': float(sum(self.stats['loss_history'][-100:]) / min(100, len(self.stats['loss_history']))),
                'current_lr': float(self.optimizer.param_groups[0]['lr']),
                'steps_done': self.progress['current_step']
            }
        }
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, indent=2, ensure_ascii=False)

        # Keep only the N most recent checkpoints
        self.cleanup_old_checkpoints()
        logger.info(f"💾 Clay Checkpoint: {filename} (loss: {loss:.4f})")

        if is_best:
            best_path = self.checkpoint_dir / "clay_best.pt"
            torch.save(checkpoint, best_path)
            logger.info("🏆 New best model saved!")

    def load_latest(self):
        """Loads the most recent checkpoint"""
        checkpoints = list(self.checkpoint_dir.glob("clay_checkpoint_*.pt"))
        if not checkpoints:
            return None, None, None, None

        # Pick the newest by modification time
        latest = max(checkpoints, key=os.path.getmtime)
        logger.info(f"🔄 Loading Clay Checkpoint: {latest.name}")
        checkpoint = torch.load(latest, map_location='cpu')

        # Restore state
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint and self.scheduler:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Restore statistics
        if 'stats' in checkpoint:
            self.stats.update(checkpoint['stats'])
        if 'progress' in checkpoint:
            self.progress.update(checkpoint['progress'])

        logger.info(f"✅ Checkpoint loaded: epoch {checkpoint['epoch']}, step {checkpoint['step']}")
        return checkpoint['epoch'], checkpoint['step'], checkpoint['loss'], checkpoint.get('timestamp')

    def cleanup_old_checkpoints(self):
        """Removes old checkpoints"""
        checkpoints = list(self.checkpoint_dir.glob("clay_checkpoint_*.pt"))
        checkpoints.sort(key=os.path.getmtime)
        keep = cfg.get('keep_checkpoints', 5)
        while len(checkpoints) > keep:
            old = checkpoints.pop(0)
            # Keep the best checkpoint
            if "best" not in old.name:
                old.unlink()
                # Remove the matching JSON summary as well
                json_file = old.with_suffix('.json')
                if json_file.exists():
                    json_file.unlink()

    def estimate_time_remaining(self):
        """Estimates the remaining training time"""
        if not self.stats['step_times']:
            return "Calculating..."
        avg_time_per_step = sum(self.stats['step_times'][-100:]) / min(100, len(self.stats['step_times']))
        steps_done = self.progress['current_step']
        total_steps = self.progress['total_steps']
        if steps_done == 0 or total_steps == 0:
            return "Calculating..."
        steps_left = total_steps - steps_done
        seconds_left = steps_left * avg_time_per_step

        # If we already have per-epoch timings, use them as a second estimate
        if self.stats['epoch_times']:
            epochs_done = self.progress['current_epoch']
            epochs_left = self.progress['total_epochs'] - epochs_done
            avg_epoch_time = sum(self.stats['epoch_times']) / len(self.stats['epoch_times'])
            epoch_based = epochs_left * avg_epoch_time
            seconds_left = min(seconds_left, epoch_based)

        return self._format_time(seconds_left)

    def _format_time(self, seconds):
        """Formats a duration"""
        if seconds < 60:
            return f"{seconds:.0f}s"
        elif seconds < 3600:
            return f"{seconds/60:.1f}m"
        elif seconds < 86400:
            return f"{seconds/3600:.1f}h"
        else:
            return f"{seconds/86400:.1f}d"

    def get_progress_bar(self, width=40):
        """Builds the progress bar string"""
        if self.progress['total_steps'] == 0:
            return "[░░░░░░░░░░░░░░░░░░░░]"
        progress = min(1.0, self.progress['current_step'] / self.progress['total_steps'])
        filled = int(width * progress)
        empty = width - filled
        bar = self.progress_style['filled'] * filled
        if filled < width:
            bar += self.progress_style['arrow']
            empty -= 1  # the arrow takes one slot, keep the bar exactly `width` characters wide
        bar += self.progress_style['empty'] * empty
        return f"[{bar}]"

    def print_progress(self, step_loss, current_lr):
        """Prints the current progress line"""
        progress_bar = self.get_progress_bar(40)
        # Percent complete
        percent = (self.progress['current_step'] / self.progress['total_steps']) * 100
        # Estimated time remaining
        eta = self.estimate_time_remaining()
        # Animated spinner
        spinner_idx = int(time.time() * 4) % len(self.progress_style['spinner'])
        spinner = self.progress_style['spinner'][spinner_idx]
        # Format the output line
        output = (f"\r{spinner} {progress_bar} {percent:5.1f}% | "
                  f"Step: {self.progress['current_step']:,}/{self.progress['total_steps']:,} | "
                  f"Loss: {step_loss:.4f} | LR: {current_lr:.6f} | "
                  f"ETA: {eta}")
        print(output, end='', flush=True)
        # At the end of the run, move to a new line
        if self.progress['current_step'] == self.progress['total_steps']:
            print()

    def start_epoch(self, epoch, total_batches):
        """Starts a new epoch"""
        self.progress['current_epoch'] = epoch
        self.progress['total_steps'] = total_batches * self.progress['total_epochs']
        self.epoch_start_time = time.time()
        logger.info(f"\n🚀 STARTING EPOCH {epoch+1}/{self.progress['total_epochs']}")
        logger.info(f" • Steps per epoch: {total_batches:,}")
        logger.info(f" • Total steps: {self.progress['total_steps']:,}")
        # Time estimate for this epoch
        if self.stats['epoch_times']:
            avg_epoch_time = sum(self.stats['epoch_times']) / len(self.stats['epoch_times'])
            logger.info(f" • Estimated epoch time: {self._format_time(avg_epoch_time)}")
            logger.info(f" • Estimated time to finish: {self.estimate_time_remaining()}")

    def end_epoch(self, epoch_loss):
        """Finishes the current epoch"""
        epoch_time = time.time() - self.epoch_start_time
        self.stats['epoch_times'].append(epoch_time)
        # Update the best loss
        if epoch_loss < self.stats['best_loss']:
            self.stats['best_loss'] = epoch_loss
        # Compute the remaining time
        avg_epoch_time = sum(self.stats['epoch_times']) / len(self.stats['epoch_times'])
        epochs_left = self.progress['total_epochs'] - self.progress['current_epoch'] - 1
        total_time_left = avg_epoch_time * epochs_left
        # Epoch summary
        logger.info(f"📊 EPOCH {self.progress['current_epoch']+1} FINISHED:")
        logger.info(f" • Loss: {epoch_loss:.4f}")
        logger.info(f" • Best loss: {self.stats['best_loss']:.4f}")
        logger.info(f" • Epoch time: {self._format_time(epoch_time)}")
        logger.info(f" • Average time/epoch: {self._format_time(avg_epoch_time)}")
        logger.info(f" • Remaining: ~{self._format_time(total_time_left)}")
        # Estimated completion time
        if epochs_left > 0:
            eta_time = datetime.now() + timedelta(seconds=total_time_left)
            logger.info(f" • Estimated completion: {eta_time.strftime('%Y-%m-%d %H:%M:%S')}")


# ==================== MODEL ARCHITECTURE ====================
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim, dropout=0.1):
        super().__init__()
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, 3 * head_dim)
        # Project back to head_dim so that concatenating num_heads heads yields embed_dim
        self.proj = nn.Linear(head_dim, head_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)
        out = attn @ v
        out = self.proj(out)
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        self.heads = nn.ModuleList([
            AttentionHead(embed_dim, self.head_dim, dropout)
            for _ in range(num_heads)
        ])
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Run all heads in parallel
        head_outputs = [head(x, mask) for head in self.heads]
        # Concatenate along the feature dimension
        out = torch.cat(head_outputs, dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.ff = FeedForward(embed_dim, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attn_out = self.attention(self.norm1(x), mask)
        x = x + attn_out
        # Feed-forward with residual connection
        ff_out = self.ff(self.norm2(x))
        x = x + ff_out
        return x


class MiniGPT60M(nn.Module):
    def __init__(self):
        super().__init__()
        # Embeddings
        self.token_embedding = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
        self.position_embedding = nn.Embedding(cfg.max_len, cfg.embed_dim)
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg.embed_dim, cfg.n_heads, cfg.ff_dim, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])
        # Final layers
        self.norm = nn.LayerNorm(cfg.embed_dim)
        self.head = nn.Linear(cfg.embed_dim, cfg.vocab_size)
        # Initialize weights
        self.apply(self._init_weights)
        # Count parameters
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.info("🤖 MiniGPT-60M model created")
        logger.info(f" • Parameters: {total_params:,} (trainable: {trainable_params:,})")
        logger.info(f" • Embed dim: {cfg.embed_dim}")
        logger.info(f" • Layers: {cfg.n_layers}")
        logger.info(f" • Heads: {cfg.n_heads}")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # Embeddings
        tok_emb = self.token_embedding(idx)
        pos = torch.arange(T, device=idx.device).unsqueeze(0)
        pos_emb = self.position_embedding(pos)
        x = tok_emb + pos_emb

        # Causal mask, broadcastable over the per-head (B, T, T) attention scores
        mask = torch.tril(torch.ones(T, T, device=idx.device)).view(1, T, T)

        # Transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        # Final layers
        x = self.norm(x)
        logits = self.head(x)

        # Loss if targets are provided
        loss = None
        if targets is not None:
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
        return logits, loss

    def generate(self, input_ids=None, max_length=100, temperature=1.0, **kwargs):
        """Generic generate method supporting several calling conventions"""
        # Ignore unsupported arguments (device, max_len, etc.)
        max_len = kwargs.get('max_len', max_length)
        if isinstance(input_ids, str):
            return self.generate_text(input_ids, max_len=max_len, temperature=temperature)
        elif input_ids is not None:
            # If it is a tensor, convert it back to text first
            if hasattr(input_ids, 'tolist'):
                if len(input_ids.shape) > 1:
                    text = tokenizer.decode(input_ids[0].tolist())
                else:
                    text = tokenizer.decode(input_ids.tolist())
            else:
                text = tokenizer.decode(input_ids)
            return self.generate_text(text, max_len=max_len, temperature=temperature)
        else:
            return ""  # Empty string for None input

    def generate_text(self, prompt, max_len=100, temperature=1.0):
        """Generates text from a prompt"""
        self.eval()
        # Encode the prompt
        input_ids = tokenizer.encode(prompt)
        if len(input_ids) == 0:
            return prompt

        # Run on the same device as the model parameters
        device = next(self.parameters()).device
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
        generated = input_ids.copy()

        with torch.no_grad():
            for _ in range(max_len):
                # Forward pass
                logits, _ = self(input_tensor)
                # Logits for the last token
                next_token_logits = logits[0, -1, :] / temperature

                # Top-k sampling
                if cfg.top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, cfg.top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Top-p (nucleus) sampling
                if cfg.top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > cfg.top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Sample the next token
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()

                # Append to the sequence
                generated.append(next_token)

                # Update the input (keep at most cfg.max_len tokens of context)
                input_tensor = torch.tensor([generated[-cfg.max_len:]], dtype=torch.long, device=device)

                # Stop on the end-of-line token
                if next_token == tokenizer.char_to_idx.get('\n', -1):
                    break

        # Decode
        generated_text = tokenizer.decode(generated)
        return generated_text


# ==================== TRAINING FUNCTIONS ====================
def create_optimizer(model, learning_rate=3e-4, weight_decay=0.1):
    """Creates an AdamW optimizer with selective weight decay"""
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'weight' in name and len(param.shape) > 1:
            decay_params.append(param)
        else:
            no_decay_params.append(param)
    optimizer = optim.AdamW(
        [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0}
        ],
        lr=learning_rate,
        betas=(cfg.adam_beta1, cfg.adam_beta2),
        eps=cfg.adam_eps
    )
    return optimizer


def create_scheduler(optimizer, warmup_steps=2000):
    """Creates a learning-rate scheduler with warmup"""
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=cfg.learning_rate,
        total_steps=cfg.epochs * 1000,  # Rough estimate
        pct_start=warmup_steps / (cfg.epochs * 1000),
        anneal_strategy='cos'
    )
    return scheduler


def benchmark_device():
    """Benchmarks the available hardware"""
    logger.info("🏃‍♂️ DEVICE BENCHMARK:")
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(0)
        memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f" • GPU: {device_name}")
        logger.info(f" • GPU memory: {memory_gb:.1f} GB")
        # GPU throughput test
        start = time.time()
        x = torch.randn(1024, 768, device='cuda')
        y = torch.randn(768, 512, device='cuda')
        for _ in range(100):
            _ = torch.matmul(x, y)
        torch.cuda.synchronize()
        gpu_time = time.time() - start
        logger.info(f" • GPU performance: {gpu_time:.2f}s per 100 operations")

    # CPU throughput test
    x = torch.randn(1024, 768)
    y = torch.randn(768, 512)
    start = time.time()
    for _ in range(100):
        _ = torch.matmul(x, y)
    cpu_time = time.time() - start
    logger.info(f" • CPU performance: {cpu_time:.2f}s per 100 operations")

    if torch.cuda.is_available():
        speedup = cpu_time / gpu_time
        logger.info(f" • GPU vs CPU speedup: {speedup:.1f}x")


def estimate_training_time(total_samples, batch_size, epochs, device):
    """Estimates the training time before starting"""
    steps_per_epoch = math.ceil(total_samples / batch_size)
    total_steps = steps_per_epoch * epochs
    logger.info("⏱️ TRAINING TIME ESTIMATE:")
    logger.info(f" • Samples: {total_samples:,}")
    logger.info(f" • Epochs: {epochs}")
    logger.info(f" • Batch size: {batch_size}")
    logger.info(f" • Steps: {total_steps:,}")

    # Rough benchmarks for different devices (seconds per 1000 steps)
    benchmarks = {
        'cuda': {'T4': 120, 'P100': 90, 'V100': 60, 'A100': 30, 'default': 100},
        'mps': {'M1': 150, 'M2': 100, 'default': 120},
        'cpu': {'default': 600}
    }
    steps_in_k = total_steps / 1000

    if device == 'cuda':
        try:
            gpu_name = torch.cuda.get_device_name(0)
            if 'T4' in gpu_name:
                time_per_k = benchmarks['cuda']['T4']
            elif 'P100' in gpu_name:
                time_per_k = benchmarks['cuda']['P100']
            elif 'V100' in gpu_name:
                time_per_k = benchmarks['cuda']['V100']
            elif 'A100' in gpu_name:
                time_per_k = benchmarks['cuda']['A100']
            else:
                time_per_k = benchmarks['cuda']['default']
        except Exception:
            time_per_k = benchmarks['cuda']['default']
    elif device == 'mps':
        time_per_k = benchmarks['mps']['default']
    else:
        time_per_k = benchmarks['cpu']['default']

    estimated_total = steps_in_k * time_per_k

    # Format the estimate
    if estimated_total > 3600 * 24:
        days = estimated_total / (3600 * 24)
        time_str = f"~{days:.1f} days"
    elif estimated_total > 3600:
        hours = estimated_total / 3600
        time_str = f"~{hours:.1f} hours"
    elif estimated_total > 60:
        minutes = estimated_total / 60
        time_str = f"~{minutes:.1f} minutes"
    else:
        time_str = f"~{estimated_total:.0f} seconds"
    logger.info(f" • Estimated time: {time_str}")

    # Estimated completion timestamp
    eta = datetime.now() + timedelta(seconds=estimated_total)
    logger.info(f" • Estimated completion: {eta.strftime('%Y-%m-%d %H:%M:%S')}")
    return total_steps, estimated_total


def train_model(resume=False):
    """Main training entry point"""
    logger.info("=" * 60)
    logger.info("🚀 STARTING MINIGPT-60M TRAINING")
    logger.info("=" * 60)

    # Device setup
    device = sys_config.device
    if device == 'cuda' and not torch.cuda.is_available():
        logger.warning("⚠ CUDA unavailable, falling back to CPU")
        device = 'cpu'
    logger.info(f"🖥️ Device: {device.upper()}")

    # Benchmark
    benchmark_device()

    # Load the data
    data_file = Path(cfg.prepared_dir) / "all_data.txt"
    if not data_file.exists():
        logger.error(f"❌ Missing data: {data_file}")
        logger.info("💡 Run: python main.py --prepare")
        return

    dataset = TextDataset(str(data_file), max_len=cfg.max_len)
    train_loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory
    )

    # Build the model
    model = MiniGPT60M()
    model.to(device)

    # Loss function (the model also computes its own loss internally)
    criterion = nn.CrossEntropyLoss()

    # Optimizer & scheduler
    optimizer = create_optimizer(model, cfg.learning_rate, cfg.weight_decay)
    scheduler = create_scheduler(optimizer, cfg.warmup_steps)

    # Clay Checkpoint System
    clay = ClayCheckpoint(model, optimizer, scheduler)

    # Resume if requested
    start_epoch = 0
    start_step = 0
    if resume:
        loaded_epoch, loaded_step, loaded_loss, timestamp = clay.load_latest()
        if loaded_epoch is not None:
            start_epoch = loaded_epoch
            start_step = loaded_step
            logger.info(f"🔄 Resuming from epoch {start_epoch}, step {start_step}")
            logger.info(f" • Last loss: {loaded_loss:.4f}")
            logger.info(f" • Timestamp: {timestamp}")
        else:
            logger.warning("⚠ No checkpoint found, starting from scratch")

    # Time estimate
    total_steps, estimated_total = estimate_training_time(
        len(dataset), cfg.batch_size, cfg.epochs - start_epoch, device
    )

    logger.info("=" * 60)
    logger.info("🎬 TRAINING STARTED!")
    logger.info("=" * 60)

    # Main training loop
    best_loss = float('inf')
    for epoch in range(start_epoch, cfg.epochs):
        clay.start_epoch(epoch, len(train_loader))
        epoch_loss = 0
        model.train()

        for batch_idx, (x, y) in enumerate(train_loader):
            step_start_time = time.time()
            x, y = x.to(device), y.to(device)

            # Forward pass
            logits, loss = model(x, y)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad)
            optimizer.step()
            if scheduler:
                scheduler.step()

            # Update statistics
            step_time = time.time() - step_start_time
            clay.stats['step_times'].append(step_time)
            clay.stats['loss_history'].append(loss.item())
            clay.stats['learning_rates'].append(optimizer.param_groups[0]['lr'])

            # Update progress
            clay.progress['current_step'] = epoch * len(train_loader) + batch_idx
            epoch_loss += loss.item()

            # Show progress roughly every 1% of batches (and for the first 10 steps)
            if batch_idx % max(1, len(train_loader) // 100) == 0 or batch_idx < 10:
                clay.print_progress(loss.item(), optimizer.param_groups[0]['lr'])

            # Save a checkpoint every 1000 steps
            if clay.progress['current_step'] % 1000 == 0 and clay.progress['current_step'] > 0:
                is_best = loss.item() < best_loss
                if is_best:
                    best_loss = loss.item()
                clay.save(
                    clay.progress['current_step'],
                    epoch,
                    loss.item(),
                    is_best=is_best
                )

        # End of epoch
        avg_epoch_loss = epoch_loss / len(train_loader)
        clay.end_epoch(avg_epoch_loss)

        # Save the model every epoch
        model_path = Path(cfg.model_dir) / f"minigpt_epoch_{epoch + 1}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'loss': avg_epoch_loss,
            'config': cfg.__dict__
        }, model_path)
        logger.info(f"💾 Model saved: {model_path.name}")

        # Also save as the best model if this is the best loss so far
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            best_path = Path(cfg.model_dir) / "minigpt_best.pt"
            torch.save(model.state_dict(), best_path)
Loss: {avg_epoch_loss:.4f}") # ==================== KONIEC TRENINGU ==================== total_time = time.time() - clay.stats['start_time'] logger.info("=" * 60) logger.info("🎉 TRENING ZAKOŃCZONY!") logger.info("=" * 60) # Podsumowanie statystyk logger.info(f"📊 PODSUMOWANIE:") logger.info(f" • Całkowity czas: {clay._format_time(total_time)}") logger.info(f" • Ostatni loss: {clay.stats['loss_history'][-1]:.4f}") logger.info(f" • Najlepszy loss: {clay.stats['best_loss']:.4f}") logger.info(f" • Średni loss: {sum(clay.stats['loss_history']) / len(clay.stats['loss_history']):.4f}") logger.info(f" • Średni czas/krok: {sum(clay.stats['step_times']) / len(clay.stats['step_times']):.3f}s") logger.info(f" • Średni czas/epokę: {sum(clay.stats['epoch_times']) / len(clay.stats['epoch_times']):.1f}s") logger.info(f" • Całkowite kroki: {clay.progress['current_step']:,}") # Zapisz finalne statystyki stats_path = Path(cfg.checkpoints_dir) / "training_stats.json" final_stats = { 'total_time': total_time, 'best_loss': float(clay.stats['best_loss']), 'final_loss': float(clay.stats['loss_history'][-1]), 'avg_loss': float(sum(clay.stats['loss_history']) / len(clay.stats['loss_history'])), 'total_steps': int(clay.progress['current_step']), 'total_epochs': int(cfg.epochs), 'learning_rate_history': [float(lr) for lr in clay.stats['learning_rates']], 'loss_history': [float(loss) for loss in clay.stats['loss_history']], 'completion_time': datetime.now().strftime("%Y-%m-%d %H:%M:%S") } with open(stats_path, 'w', encoding='utf-8') as f: json.dump(final_stats, f, indent=2, ensure_ascii=False) logger.info(f"📈 Statystyki zapisane: {stats_path}") # Zapisz finalny model final_path = Path(cfg.model_dir) / "minigpt_final.pt" torch.save({ 'model_state_dict': model.state_dict(), 'config': cfg.__dict__, 'stats': final_stats }, final_path) logger.info(f"💾 Finalny model zapisany: {final_path}") # Test generowania logger.info(f"\n🧪 TEST GENEROWANIA:") test_prompts = [ "Witaj, ", "Python to ", "Dzisiaj jest ", "AI to " ] for prompt in test_prompts: generated = model.generate_text(prompt, max_len=50, temperature=0.8) logger.info(f" • '{prompt}' -> '{generated[:50]}...'") logger.info("=" * 60) logger.info("✅ TRENING ZAKOŃCZONY POMYŚLNIE!") logger.info("=" * 60) def continue_training(): """Kontynuuje trening od ostatniego checkpointu""" logger.info("🔄 KONTYNUUJĘ TRENING OD CHECKPOINTU") train_model(resume=True) def train_more_epochs(additional_epochs=3): """Dodaje więcej epok treningu""" logger.info(f"📈 DODAJĘ {additional_epochs} EPOK TRENINGU") # Zaktualizuj liczbę epok w konfiguracji original_epochs = cfg.epochs cfg.epochs = original_epochs + additional_epochs logger.info(f" • Obecne epoki: {original_epochs}") logger.info(f" • Nowe epoki: {cfg.epochs}") logger.info(f" • Dodaję: {additional_epochs} epok") # Kontynuuj trening train_model(resume=True) # Przywróć oryginalną wartość cfg.epochs = original_epochs def load_model(model_path=None, device=None): """Wczytuje model z pliku""" if device is None: device = sys_config.device if model_path is None: # Szukaj najnowszego modelu models = list(Path(cfg.model_dir).glob("*.pt")) if not models: logger.error("❌ Brak modelu do wczytania!") return None # Preferuj najlepszy, potem finalny, potem najnowszy best_path = Path(cfg.model_dir) / "minigpt_best.pt" final_path = Path(cfg.model_dir) / "minigpt_final.pt" if best_path.exists(): model_path = best_path logger.info("📂 Wczytuję najlepszy model") elif final_path.exists(): model_path = final_path logger.info("📂 Wczytuję finalny model") else: 
            models.sort(key=lambda x: x.stat().st_mtime, reverse=True)
            model_path = models[0]
            logger.info(f"📂 Loading the newest model: {model_path.name}")

    logger.info(f"🔄 Loading model: {model_path}")
    try:
        checkpoint = torch.load(model_path, map_location='cpu')

        # Build the model
        model = MiniGPT60M()
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # Old format - raw weights only
            model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()
        logger.info("✅ Model loaded successfully")

        # Show checkpoint info
        if 'config' in checkpoint:
            logger.info(f" • Config: {checkpoint['config'].get('n_layers', 'N/A')} layers")
        if 'loss' in checkpoint:
            logger.info(f" • Loss: {checkpoint['loss']:.4f}")
        if 'epoch' in checkpoint:
            logger.info(f" • Epoch: {checkpoint['epoch']}")
        return model
    except Exception as e:
        logger.error(f"❌ Error loading model: {e}")
        return None


def generate_text(prompt, model_path=None, device=None, max_length=200, temperature=0.8):
    """Generates text from a prompt"""
    logger.info(f"🎨 GENERATING TEXT: '{prompt}'")

    # Load the model
    model = load_model(model_path, device)
    if model is None:
        logger.error("❌ Could not load the model!")
        return

    # Generate
    start_time = time.time()
    generated = model.generate_text(prompt, max_len=max_length, temperature=temperature)
    gen_time = time.time() - start_time

    # Show the result
    logger.info("=" * 60)
    logger.info("📝 OUTPUT:")
    logger.info(generated)
    logger.info("=" * 60)
    logger.info(f"⏱️ Generation time: {gen_time:.2f}s")
    logger.info(f"📏 Length: {len(generated)} characters")

    # Save to file
    output_dir = Path("results")
    output_dir.mkdir(exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = output_dir / f"generated_{timestamp}.txt"
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(f"Prompt: {prompt}\n")
        f.write(f"Generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Temperature: {temperature}\n")
        f.write(f"Max length: {max_length}\n")
        f.write("-" * 60 + "\n")
        f.write(generated + "\n")
    logger.info(f"💾 Result saved: {output_file}")
    return generated


def start_chat(model_path=None, device=None):
    """Starts an interactive chat session with the model"""
    logger.info("💬 STARTING CHAT MODE")
    logger.info("=" * 60)
    logger.info("Type 'exit' to quit")
    logger.info("Type 'reset' to reset the context")
    logger.info("=" * 60)

    # Load the model
    model = load_model(model_path, device)
    if model is None:
        return

    # Conversation history
    conversation_history = []
    max_history = 5

    while True:
        try:
            # Read user input
            user_input = input("\n👤 You: ").strip()
            if user_input.lower() == 'exit':
                logger.info("👋 Goodbye!")
                break
            if user_input.lower() == 'reset':
                conversation_history = []
                logger.info("🔄 History reset")
                continue
            if not user_input:
                continue

            # Build the prompt from the recent history
            if conversation_history:
                prompt = "\n".join(conversation_history[-max_history:]) + "\n👤 " + user_input + "\n🤖 "
            else:
                prompt = "👤 " + user_input + "\n🤖 "

            # Generate a response
            logger.info("🤖 Thinking...")
            start_time = time.time()
            response = model.generate_text(
                prompt,
                max_len=200,
                temperature=0.8
            )
            # Keep only the model's part of the reply
            if "🤖 " in response:
                response = response.split("🤖 ")[-1].strip()
            gen_time = time.time() - start_time

            # Show the response
            print(f"🤖 AI: {response}")
            print(f"   ⏱️ {gen_time:.2f}s")

            # Append to the history
            conversation_history.append(f"👤 {user_input}")
            conversation_history.append(f"🤖 {response}")

            # Trim the history
            if len(conversation_history) > max_history * 2:
                conversation_history = conversation_history[-(max_history * 2):]

        except KeyboardInterrupt:
            logger.info("\n👋 Interrupted by the user")
            break
        except Exception as e:
            logger.error(f"❌ Error: {e}")
            continue


def evaluate_model(model_path=None, device=None, test_samples=100):
    """Evaluates the model on held-out samples"""
    logger.info("📊 EVALUATING MODEL")

    # Resolve the device so test tensors land next to the model
    if device is None:
        device = sys_config.device

    # Load the model
    model = load_model(model_path, device)
    if model is None:
        return

    # Load the data
    data_file = Path(cfg.prepared_dir) / "all_data.txt"
    if not data_file.exists():
        logger.error("❌ No data available for evaluation")
        return
    dataset = TextDataset(str(data_file), max_len=cfg.max_len)

    # Pick test samples
    test_indices = random.sample(range(len(dataset)), min(test_samples, len(dataset)))

    model.eval()
    total_loss = 0
    total_perplexity = 0
    logger.info(f"🔍 Testing on {len(test_indices)} samples...")

    with torch.no_grad():
        for i, idx in enumerate(test_indices):
            x, y = dataset[idx]
            x = x.unsqueeze(0).to(device)
            y = y.unsqueeze(0).to(device)
            _, loss = model(x, y)
            total_loss += loss.item()
            # Perplexity
            perplexity = torch.exp(loss).item()
            total_perplexity += perplexity
            if (i + 1) % 10 == 0:
                logger.info(f" Processed {i + 1}/{len(test_indices)}...")

    avg_loss = total_loss / len(test_indices)
    avg_perplexity = total_perplexity / len(test_indices)

    logger.info("=" * 60)
    logger.info("📈 EVALUATION RESULTS:")
    logger.info(f" • Average loss: {avg_loss:.4f}")
    logger.info(f" • Perplexity: {avg_perplexity:.2f}")
    logger.info(f" • Test samples: {len(test_indices)}")
    logger.info("=" * 60)

    # Save the results
    eval_dir = Path("evaluation")
    eval_dir.mkdir(exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    eval_file = eval_dir / f"eval_{timestamp}.json"
    results = {
        'timestamp': timestamp,
        'avg_loss': float(avg_loss),
        'avg_perplexity': float(avg_perplexity),
        'test_samples': len(test_indices),
        'model': str(model_path)
    }
    with open(eval_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    logger.info(f"💾 Results saved: {eval_file}")
    return avg_loss, avg_perplexity


# ==================== MAIN GUARD ====================
if __name__ == "__main__":
    # Quick self-test
    print("🤖 MiniGPT-60M AI Module")
    print("Use: python main.py --train / --generate / --chat")

    # Tokenizer test
    test_text = "Hello world!"
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)
    print(f"\n🧪 Tokenizer test: '{test_text}' -> {encoded} -> '{decoded}'")

    # Model test
    model = MiniGPT60M()
    test_input = torch.randint(0, cfg.vocab_size, (2, 16))
    logits, _ = model(test_input)
    print(f"🧪 Model test: input {test_input.shape} -> output {logits.shape}")
    print("✅ AI module ready to use!")
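
    # Optional extra smoke test, not part of the original self-test: a minimal sketch that
    # exercises the generic generate() wrapper on the randomly initialised model. It assumes
    # the sampling knobs cfg.top_k and cfg.top_p used by generate_text() are set in config.py;
    # the output is meaningless before training and only confirms that sampling runs end to end.
    sample = model.generate("Hello", max_length=20, temperature=1.0)
    print(f"🧪 generate() smoke test: '{sample[:40]}'")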