"""
🎯 AI.PY - MiniGPT-60M with the Clay Checkpoint System
"""

import os
import sys
import time
import math
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datetime import datetime, timedelta
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import random
import numpy as np

from config import logger, cfg, sys_config


# ==================== TOKENIZER ====================
class Tokenizer:
    def __init__(self):
        self.vocab = cfg.vocab
        self.vocab_size = cfg.vocab_size
        self.char_to_idx = {ch: i for i, ch in enumerate(self.vocab)}
        self.idx_to_char = {i: ch for i, ch in enumerate(self.vocab)}

    def encode(self, text: str) -> List[int]:
        """Encodes text into a list of indices"""
        return [self.char_to_idx.get(ch, 0) for ch in text if ch in self.char_to_idx]

    def decode(self, indices: List[int]) -> str:
        """Decodes a list of indices back into text"""
        return ''.join([self.idx_to_char.get(idx, '') for idx in indices])

    def encode_batch(self, texts: List[str]) -> torch.Tensor:
        """Encodes a batch of texts"""
        encoded = [self.encode(text) for text in texts]
        max_len = max(len(e) for e in encoded)
        padded = [e + [0] * (max_len - len(e)) for e in encoded]
        return torch.tensor(padded, dtype=torch.long)

    def decode_batch(self, tensors: torch.Tensor) -> List[str]:
        """Decodes a batch of tensors"""
        texts = []
        for tensor in tensors:
            indices = tensor.tolist()
            # Remove padding (0)
            indices = [idx for idx in indices if idx != 0]
            texts.append(self.decode(indices))
        return texts


# Global tokenizer
tokenizer = Tokenizer()

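
# Illustrative usage sketch (assumption: not part of the original module): round-trips a
# short string through the character-level tokenizer. Assumes the characters used below
# exist in cfg.vocab; characters outside the vocabulary are simply dropped by encode().
def _tokenizer_example() -> None:
    ids = tokenizer.encode("abc")                    # e.g. [4, 5, 6], depending on cfg.vocab
    text = tokenizer.decode(ids)                     # "abc" if all characters are in the vocab
    batch = tokenizer.encode_batch(["ab", "abcd"])   # right-padded with 0 to equal length
    print(ids, text, batch.shape)
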
# ==================== DATASET ====================
class TextDataset(Dataset):
    def __init__(self, filepath: str, max_len: int = 512):
        self.filepath = filepath
        self.max_len = max_len

        # Load the data
        with open(filepath, 'r', encoding='utf-8') as f:
            self.lines = [line.strip() for line in f if line.strip()]

        logger.info(f"📊 Dataset: {len(self.lines):,} lines")

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        text = self.lines[idx]

        # Trim to max_len (a random window) if the line is too long
        if len(text) > self.max_len:
            start = random.randint(0, len(text) - self.max_len)
            text = text[start:start + self.max_len]

        # Encode
        encoded = tokenizer.encode(text)

        # Pad if too short
        if len(encoded) < self.max_len:
            encoded = encoded + [0] * (self.max_len - len(encoded))

        x = torch.tensor(encoded[:-1], dtype=torch.long)
        y = torch.tensor(encoded[1:], dtype=torch.long)

        return x, y

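
# Illustrative sketch (assumption: not part of the original module): shows how __getitem__
# pairs inputs and targets for next-token prediction. For a sequence [t0, t1, t2, t3],
# x = [t0, t1, t2] and y = [t1, t2, t3], so the model learns to predict each successor.
def _dataset_shift_example() -> None:
    encoded = [7, 8, 9, 10]
    x = torch.tensor(encoded[:-1])  # inputs:  [7, 8, 9]
    y = torch.tensor(encoded[1:])   # targets: [8, 9, 10]
    print(x.tolist(), y.tolist())
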
# ==================== CLAY CHECKPOINT SYSTEM ====================
class ClayCheckpoint:
    """Advanced checkpoint system with the Clay UI"""

    def __init__(self, model, optimizer, scheduler=None):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.checkpoint_dir = Path(cfg.checkpoints_dir)
        self.checkpoint_dir.mkdir(exist_ok=True)

        # Training stats
        self.stats = {
            'start_time': time.time(),
            'epoch_times': [],
            'step_times': [],
            'loss_history': [],
            'learning_rates': [],
            'best_loss': float('inf')
        }

        # Progress tracking
        self.progress = {
            'current_step': 0,
            'total_steps': 0,
            'current_epoch': 0,
            'total_epochs': cfg.epochs,
            'estimated_completion': None
        }

        # Progress bar style
        self.progress_style = {
            'filled': '█',
            'empty': '░',
            'arrow': '▶',
            'spinner': ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
        }

    def save(self, step, epoch, loss, is_best=False, force=False):
        """Saves a checkpoint"""
        if not force and step % cfg.checkpoint_freq != 0:
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"clay_checkpoint_ep{epoch}_step{step}_{timestamp}.pt"
        filepath = self.checkpoint_dir / filename

        # Assemble the state
        checkpoint = {
            'step': step,
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
            'stats': self.stats,
            'progress': self.progress,
            'timestamp': timestamp,
            'config': {
                'vocab_size': cfg.vocab_size,
                'embed_dim': cfg.embed_dim,
                'n_layers': cfg.n_layers,
                'n_heads': cfg.n_heads
            }
        }

        if self.scheduler:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

        # Save
        torch.save(checkpoint, filepath)

        # Also save a JSON summary for readability
        json_path = filepath.with_suffix('.json')
        json_data = {
            'checkpoint_info': {
                'filename': filename,
                'step': step,
                'epoch': epoch,
                'loss': float(loss),
                'timestamp': timestamp,
                'file_size': os.path.getsize(filepath)
            },
            'training_stats': {
                'total_time': time.time() - self.stats['start_time'],
                'avg_loss': float(sum(self.stats['loss_history'][-100:]) / min(100, len(self.stats['loss_history']))),
                'current_lr': float(self.optimizer.param_groups[0]['lr']),
                'steps_done': self.progress['current_step']
            }
        }

        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, indent=2, ensure_ascii=False)

        # Keep only the N most recent checkpoints
        self.cleanup_old_checkpoints()

        logger.info(f"💾 Clay Checkpoint: {filename} (loss: {loss:.4f})")

        if is_best:
            best_path = self.checkpoint_dir / "clay_best.pt"
            torch.save(checkpoint, best_path)
            logger.info(f"🏆 New best model saved!")

    def load_latest(self):
        """Loads the most recent checkpoint"""
        checkpoints = list(self.checkpoint_dir.glob("clay_checkpoint_*.pt"))
        if not checkpoints:
            return None, None, None, None

        # Find the newest by modification time
        latest = max(checkpoints, key=os.path.getmtime)

        logger.info(f"🔄 Loading Clay Checkpoint: {latest.name}")
        checkpoint = torch.load(latest, map_location='cpu')

        # Restore state
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'scheduler_state_dict' in checkpoint and self.scheduler:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Restore statistics
        if 'stats' in checkpoint:
            self.stats.update(checkpoint['stats'])

        if 'progress' in checkpoint:
            self.progress.update(checkpoint['progress'])

        logger.info(f"✅ Checkpoint loaded: epoch {checkpoint['epoch']}, step {checkpoint['step']}")

        return checkpoint['epoch'], checkpoint['step'], checkpoint['loss'], checkpoint.get('timestamp')

    def cleanup_old_checkpoints(self):
        """Removes old checkpoints"""
        checkpoints = list(self.checkpoint_dir.glob("clay_checkpoint_*.pt"))
        checkpoints.sort(key=os.path.getmtime)

        keep = cfg.get('keep_checkpoints', 5)
        while len(checkpoints) > keep:
            old = checkpoints.pop(0)
            # Keep the best checkpoint
            if "best" not in old.name:
                old.unlink()
                # Remove the matching JSON as well
                json_file = old.with_suffix('.json')
                if json_file.exists():
                    json_file.unlink()

    def estimate_time_remaining(self):
        """Estimates the remaining training time"""
        if not self.stats['step_times']:
            return "Calculating..."

        avg_time_per_step = sum(self.stats['step_times'][-100:]) / min(100, len(self.stats['step_times']))
        steps_done = self.progress['current_step']
        total_steps = self.progress['total_steps']

        if steps_done == 0 or total_steps == 0:
            return "Calculating..."

        steps_left = total_steps - steps_done
        seconds_left = steps_left * avg_time_per_step

        # If epoch timings are available, use them as well
        if self.stats['epoch_times']:
            epochs_done = self.progress['current_epoch']
            epochs_left = self.progress['total_epochs'] - epochs_done
            avg_epoch_time = sum(self.stats['epoch_times']) / len(self.stats['epoch_times'])
            epoch_based = epochs_left * avg_epoch_time
            seconds_left = min(seconds_left, epoch_based)

        return self._format_time(seconds_left)

    def _format_time(self, seconds):
        """Formats a duration"""
        if seconds < 60:
            return f"{seconds:.0f}s"
        elif seconds < 3600:
            return f"{seconds/60:.1f}m"
        elif seconds < 86400:
            return f"{seconds/3600:.1f}h"
        else:
            return f"{seconds/86400:.1f}d"

    def get_progress_bar(self, width=40):
        """Builds the progress bar string"""
        if self.progress['total_steps'] == 0:
            return "[░░░░░░░░░░░░░░░░░░░░]"

        progress = min(1.0, self.progress['current_step'] / self.progress['total_steps'])
        filled = int(width * progress)
        empty = width - filled

        bar = self.progress_style['filled'] * filled
        if filled < width:
            bar += self.progress_style['arrow']
        bar += self.progress_style['empty'] * empty

        return f"[{bar}]"

    def print_progress(self, step_loss, current_lr):
        """Prints the current progress"""
        progress_bar = self.get_progress_bar(40)

        # Percentage complete
        percent = (self.progress['current_step'] / self.progress['total_steps']) * 100

        # Estimated time remaining
        eta = self.estimate_time_remaining()

        # Spinner (animated)
        spinner_idx = int(time.time() * 4) % len(self.progress_style['spinner'])
        spinner = self.progress_style['spinner'][spinner_idx]

        # Format the output
        output = (f"\r{spinner} {progress_bar} {percent:5.1f}% | "
                  f"Step: {self.progress['current_step']:,}/{self.progress['total_steps']:,} | "
                  f"Loss: {step_loss:.4f} | LR: {current_lr:.6f} | "
                  f"ETA: {eta}")

        print(output, end='', flush=True)

        # At the end of the epoch, move to a new line
        if self.progress['current_step'] == self.progress['total_steps']:
            print()

    def start_epoch(self, epoch, total_batches):
        """Starts a new epoch"""
        self.progress['current_epoch'] = epoch
        self.progress['total_steps'] = total_batches * self.progress['total_epochs']
        self.epoch_start_time = time.time()

        logger.info(f"\n🚀 STARTING EPOCH {epoch+1}/{self.progress['total_epochs']}")
        logger.info(f"   • Steps in this epoch: {total_batches:,}")
        logger.info(f"   • Total steps: {self.progress['total_steps']:,}")

        # Time estimate for this epoch
        if self.stats['epoch_times']:
            avg_epoch_time = sum(self.stats['epoch_times']) / len(self.stats['epoch_times'])
            logger.info(f"   • Estimated epoch time: {self._format_time(avg_epoch_time)}")
            logger.info(f"   • Estimated time to finish: {self.estimate_time_remaining()}")

    def end_epoch(self, epoch_loss):
        """Finishes an epoch"""
        epoch_time = time.time() - self.epoch_start_time
        self.stats['epoch_times'].append(epoch_time)

        # Update the best loss
        if epoch_loss < self.stats['best_loss']:
            self.stats['best_loss'] = epoch_loss

        # Compute the remaining time
        avg_epoch_time = sum(self.stats['epoch_times']) / len(self.stats['epoch_times'])
        epochs_left = self.progress['total_epochs'] - self.progress['current_epoch'] - 1
        total_time_left = avg_epoch_time * epochs_left

        # Epoch summary
        logger.info(f"📊 EPOCH {self.progress['current_epoch']+1} FINISHED:")
        logger.info(f"   • Loss: {epoch_loss:.4f}")
        logger.info(f"   • Best loss: {self.stats['best_loss']:.4f}")
        logger.info(f"   • Epoch time: {self._format_time(epoch_time)}")
        logger.info(f"   • Average time/epoch: {self._format_time(avg_epoch_time)}")
        logger.info(f"   • Remaining: ~{self._format_time(total_time_left)}")

        # Expected completion
        if epochs_left > 0:
            eta_time = datetime.now() + timedelta(seconds=total_time_left)
            logger.info(f"   • Expected completion: {eta_time.strftime('%Y-%m-%d %H:%M:%S')}")

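
# Illustrative sketch (assumption: not part of the original module): the progress-tracking
# side of ClayCheckpoint, independent of the filesystem. save()/load_latest() additionally
# persist model and optimizer state under cfg.checkpoints_dir.
def _clay_progress_example(model, optimizer) -> None:
    clay = ClayCheckpoint(model, optimizer)
    clay.progress['total_steps'] = 200
    clay.progress['current_step'] = 50
    clay.stats['step_times'].append(0.1)      # pretend one step took 0.1 s
    print(clay.get_progress_bar(width=20))    # e.g. [█████▶░░░░░░░░░░░░░░]
    print(clay.estimate_time_remaining())     # e.g. "15s"
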
# ==================== MODEL ARCHITECTURE ====================
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim, dropout=0.1):
        super().__init__()
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, 3 * head_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, T, head_dim)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, T, T)

        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        # Each head returns head_dim features; the single output projection lives in
        # MultiHeadAttention, so the concatenated head outputs line up with embed_dim.
        out = attn @ v  # (B, T, head_dim)
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads

        self.heads = nn.ModuleList([
            AttentionHead(embed_dim, self.head_dim, dropout)
            for _ in range(num_heads)
        ])
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Run every head (each returns (B, T, head_dim))
        head_outputs = [head(x, mask) for head in self.heads]

        # Concatenate back to embed_dim and project
        out = torch.cat(head_outputs, dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.ff = FeedForward(embed_dim, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        # Self-attention with a residual connection (pre-norm)
        attn_out = self.attention(self.norm1(x), mask)
        x = x + attn_out

        # Feed-forward with a residual connection
        ff_out = self.ff(self.norm2(x))
        x = x + ff_out

        return x

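
# Illustrative sketch (assumption: not part of the original module): a shape check for a
# single TransformerBlock with a causal mask, independent of cfg. The output keeps the
# input shape, which is what lets the blocks be stacked.
def _transformer_block_shape_example() -> None:
    block = TransformerBlock(embed_dim=32, num_heads=4, ff_dim=64, dropout=0.0)
    x = torch.randn(2, 8, 32)                           # (batch, time, embed_dim)
    mask = torch.tril(torch.ones(8, 8)).view(1, 8, 8)   # broadcasts over the batch
    out = block(x, mask)
    print(out.shape)                                     # torch.Size([2, 8, 32])
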
class MiniGPT60M(nn.Module):
    def __init__(self):
        super().__init__()

        # Embeddings
        self.token_embedding = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
        self.position_embedding = nn.Embedding(cfg.max_len, cfg.embed_dim)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg.embed_dim, cfg.n_heads, cfg.ff_dim, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])

        # Final layers
        self.norm = nn.LayerNorm(cfg.embed_dim)
        self.head = nn.Linear(cfg.embed_dim, cfg.vocab_size)

        # Initialize weights
        self.apply(self._init_weights)

        # Count parameters
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        logger.info(f"🤖 MiniGPT-60M model created")
        logger.info(f"   • Parameters: {total_params:,} (trainable: {trainable_params:,})")
        logger.info(f"   • Embed dim: {cfg.embed_dim}")
        logger.info(f"   • Layers: {cfg.n_layers}")
        logger.info(f"   • Heads: {cfg.n_heads}")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # Embeddings
        tok_emb = self.token_embedding(idx)
        pos = torch.arange(T, device=idx.device).unsqueeze(0)
        pos_emb = self.position_embedding(pos)
        x = tok_emb + pos_emb

        # Causal mask, shaped (1, T, T) so it broadcasts over the per-head (B, T, T)
        # attention scores
        mask = torch.tril(torch.ones(T, T, device=idx.device)).view(1, T, T)

        # Transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        # Final layer
        x = self.norm(x)
        logits = self.head(x)

        # Loss, if targets are provided
        loss = None
        if targets is not None:
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )

        return logits, loss

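
    # Illustrative usage sketch (assumption: not part of the original class). With targets
    # supplied, forward() returns the cross-entropy loss alongside the logits, which is
    # what the training loop backpropagates:
    #
    #   model = MiniGPT60M()
    #   x = torch.randint(0, cfg.vocab_size, (2, 16))   # token ids, (batch, time)
    #   y = torch.randint(0, cfg.vocab_size, (2, 16))   # next-token targets
    #   logits, loss = model(x, y)                      # logits: (2, 16, cfg.vocab_size)
    #   loss.backward()
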
    def generate(self, input_ids=None, max_length=100, temperature=1.0, **kwargs):
        """Universal generate method that accepts several input styles"""
        # Ignore unsupported keyword arguments (device, max_len, etc.)
        max_len = kwargs.get('max_len', max_length)

        if isinstance(input_ids, str):
            return self.generate_text(input_ids, max_len=max_len, temperature=temperature)
        elif input_ids is not None:
            # If it is a tensor, convert it back to text first
            if hasattr(input_ids, 'tolist'):
                if len(input_ids.shape) > 1:
                    text = tokenizer.decode(input_ids[0].tolist())
                else:
                    text = tokenizer.decode(input_ids.tolist())
            else:
                text = tokenizer.decode(input_ids)
            return self.generate_text(text, max_len=max_len, temperature=temperature)
        else:
            return ""  # Empty string for None input

    def generate_text(self, prompt, max_len=100, temperature=1.0):
        """Generates text from a prompt"""
        self.eval()

        # Encode the prompt
        input_ids = tokenizer.encode(prompt)
        if len(input_ids) == 0:
            return prompt

        # Build the input tensor on the model's device
        device = next(self.parameters()).device
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

        generated = input_ids.copy()

        with torch.no_grad():
            for _ in range(max_len):
                # Forward pass
                logits, _ = self(input_tensor)

                # Take the logits for the last token
                next_token_logits = logits[0, -1, :] / temperature

                # Top-k sampling
                if cfg.top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, cfg.top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Top-p (nucleus) sampling
                if cfg.top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                    sorted_indices_to_remove = cumulative_probs > cfg.top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Sample the next token
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()

                # Append to the sequence
                generated.append(next_token)

                # Rebuild the input from the most recent cfg.max_len tokens
                input_tensor = torch.tensor([generated[-cfg.max_len:]], dtype=torch.long, device=device)

                # Stop on the end-of-line token
                if next_token == tokenizer.char_to_idx.get('\n', -1):
                    break

        # Decode
        generated_text = tokenizer.decode(generated)
        return generated_text

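
# Illustrative sketch (assumption: not part of the original module): prompt-based sampling
# with a freshly built (untrained) model, just to exercise the API. Lower temperatures make
# the character-level output more conservative; cfg.top_k / cfg.top_p restrict the candidates.
def _generation_example() -> None:
    model = MiniGPT60M()
    sample = model.generate_text("Witaj, ", max_len=20, temperature=0.8)
    print(sample)
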
# ==================== TRAINING FUNCTIONS ====================
def create_optimizer(model, learning_rate=3e-4, weight_decay=0.1):
    """Creates an optimizer with weight decay applied only to matrix weights"""
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if 'weight' in name and len(param.shape) > 1:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    optimizer = optim.AdamW([
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ], lr=learning_rate, betas=(cfg.adam_beta1, cfg.adam_beta2), eps=cfg.adam_eps)

    return optimizer


def create_scheduler(optimizer, warmup_steps=2000):
    """Creates a learning-rate scheduler with warmup"""
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=cfg.learning_rate,
        total_steps=cfg.epochs * 1000,  # rough estimate
        pct_start=warmup_steps / (cfg.epochs * 1000),
        anneal_strategy='cos'
    )
    return scheduler

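
# Illustrative sketch (assumption: not part of the original module): the optimizer and the
# OneCycleLR scheduler are both stepped once per batch, which is how train_model() below
# uses them.
def _optimizer_example(model) -> None:
    optimizer = create_optimizer(model, learning_rate=3e-4, weight_decay=0.1)
    scheduler = create_scheduler(optimizer, warmup_steps=2000)
    optimizer.zero_grad()
    # ... loss.backward() would run here in a real training step ...
    optimizer.step()
    scheduler.step()
    print(optimizer.param_groups[0]['lr'])
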
def benchmark_device():
    """Benchmarks the available hardware"""
    logger.info("🏃♂️ DEVICE BENCHMARK:")

    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(0)
        memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

        logger.info(f"   • GPU: {device_name}")
        logger.info(f"   • GPU memory: {memory_gb:.1f} GB")

        # GPU throughput test
        start = time.time()
        x = torch.randn(1024, 768, device='cuda')
        y = torch.randn(768, 512, device='cuda')
        for _ in range(100):
            _ = torch.matmul(x, y)
        torch.cuda.synchronize()
        gpu_time = time.time() - start

        logger.info(f"   • GPU throughput: {gpu_time:.2f}s for 100 operations")

    # CPU test
    x = torch.randn(1024, 768)
    y = torch.randn(768, 512)
    start = time.time()
    for _ in range(100):
        _ = torch.matmul(x, y)
    cpu_time = time.time() - start

    logger.info(f"   • CPU throughput: {cpu_time:.2f}s for 100 operations")

    if torch.cuda.is_available():
        speedup = cpu_time / gpu_time
        logger.info(f"   • GPU vs CPU speedup: {speedup:.1f}x")

def estimate_training_time(total_samples, batch_size, epochs, device):
    """Estimates the training time before it starts"""
    steps_per_epoch = math.ceil(total_samples / batch_size)
    total_steps = steps_per_epoch * epochs

    logger.info("⏱️ TRAINING TIME ESTIMATE:")
    logger.info(f"   • Samples: {total_samples:,}")
    logger.info(f"   • Epochs: {epochs}")
    logger.info(f"   • Batch size: {batch_size}")
    logger.info(f"   • Steps: {total_steps:,}")

    # Benchmarks for various devices (seconds per 1000 steps)
    benchmarks = {
        'cuda': {'T4': 120, 'P100': 90, 'V100': 60, 'A100': 30, 'default': 100},
        'mps': {'M1': 150, 'M2': 100, 'default': 120},
        'cpu': {'default': 600}
    }

    steps_in_k = total_steps / 1000

    if device == 'cuda':
        try:
            gpu_name = torch.cuda.get_device_name(0)
            if 'T4' in gpu_name:
                time_per_k = benchmarks['cuda']['T4']
            elif 'P100' in gpu_name:
                time_per_k = benchmarks['cuda']['P100']
            elif 'V100' in gpu_name:
                time_per_k = benchmarks['cuda']['V100']
            elif 'A100' in gpu_name:
                time_per_k = benchmarks['cuda']['A100']
            else:
                time_per_k = benchmarks['cuda']['default']
        except Exception:
            time_per_k = benchmarks['cuda']['default']

    elif device == 'mps':
        time_per_k = benchmarks['mps']['default']
    else:
        time_per_k = benchmarks['cpu']['default']

    estimated_total = steps_in_k * time_per_k

    # Format the duration
    if estimated_total > 3600 * 24:
        days = estimated_total / (3600 * 24)
        time_str = f"~{days:.1f} days"
    elif estimated_total > 3600:
        hours = estimated_total / 3600
        time_str = f"~{hours:.1f} hours"
    elif estimated_total > 60:
        minutes = estimated_total / 60
        time_str = f"~{minutes:.1f} minutes"
    else:
        time_str = f"~{estimated_total:.0f} seconds"

    logger.info(f"   • Estimated time: {time_str}")

    # Expected completion
    eta = datetime.now() + timedelta(seconds=estimated_total)
    logger.info(f"   • Expected completion: {eta.strftime('%Y-%m-%d %H:%M:%S')}")

    return total_steps, estimated_total

def train_model(resume=False):
    """Main training loop"""
    logger.info("=" * 60)
    logger.info("🚀 STARTING MINIGPT-60M TRAINING")
    logger.info("=" * 60)

    # Device setup
    device = sys_config.device
    if device == 'cuda' and not torch.cuda.is_available():
        logger.warning("⚠ CUDA unavailable, falling back to CPU")
        device = 'cpu'

    logger.info(f"🖥️ Device: {device.upper()}")

    # Benchmark
    benchmark_device()

    # Load the data
    data_file = Path(cfg.prepared_dir) / "all_data.txt"
    if not data_file.exists():
        logger.error(f"❌ Missing data: {data_file}")
        logger.info("💡 Run: python main.py --prepare")
        return

    dataset = TextDataset(str(data_file), max_len=cfg.max_len)
    train_loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory
    )

    # Build the model
    model = MiniGPT60M()
    model.to(device)

    # Loss function (unused here: the model computes cross-entropy internally in forward())
    criterion = nn.CrossEntropyLoss()

    # Optimizer & scheduler
    optimizer = create_optimizer(model, cfg.learning_rate, cfg.weight_decay)
    scheduler = create_scheduler(optimizer, cfg.warmup_steps)

    # Clay Checkpoint system
    clay = ClayCheckpoint(model, optimizer, scheduler)

    # Resume if requested
    start_epoch = 0
    start_step = 0

    if resume:
        loaded_epoch, loaded_step, loaded_loss, timestamp = clay.load_latest()
        if loaded_epoch is not None:
            start_epoch = loaded_epoch
            start_step = loaded_step
            logger.info(f"🔄 Resuming from epoch {start_epoch}, step {start_step}")
            logger.info(f"   • Last loss: {loaded_loss:.4f}")
            logger.info(f"   • Timestamp: {timestamp}")
        else:
            logger.warning("⚠ No checkpoint found, starting from scratch")

    # Time estimate
    total_steps, estimated_total = estimate_training_time(
        len(dataset), cfg.batch_size, cfg.epochs - start_epoch, device
    )

    logger.info("=" * 60)
    logger.info("🎬 TRAINING STARTED!")
    logger.info("=" * 60)

    # Main training loop
    best_loss = float('inf')

    for epoch in range(start_epoch, cfg.epochs):
        clay.start_epoch(epoch, len(train_loader))
        epoch_loss = 0

        model.train()

        for batch_idx, (x, y) in enumerate(train_loader):
            step_start_time = time.time()

            x, y = x.to(device), y.to(device)

            # Forward pass
            logits, loss = model(x, y)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad)

            optimizer.step()

            if scheduler:
                scheduler.step()

            # Update statistics
            step_time = time.time() - step_start_time
            clay.stats['step_times'].append(step_time)
            clay.stats['loss_history'].append(loss.item())
            clay.stats['learning_rates'].append(optimizer.param_groups[0]['lr'])

            # Update progress
            clay.progress['current_step'] = epoch * len(train_loader) + batch_idx
            epoch_loss += loss.item()

            # Show progress roughly every 1% of batches (and for the first 10 steps)
            if batch_idx % max(1, len(train_loader) // 100) == 0 or batch_idx < 10:
                clay.print_progress(loss.item(), optimizer.param_groups[0]['lr'])

            # Save a checkpoint every 1000 steps
            if clay.progress['current_step'] % 1000 == 0 and clay.progress['current_step'] > 0:
                is_best = loss.item() < best_loss
                if is_best:
                    best_loss = loss.item()

                clay.save(
                    clay.progress['current_step'],
                    epoch,
                    loss.item(),
                    is_best=is_best
                )

        # End of the epoch
        avg_epoch_loss = epoch_loss / len(train_loader)
        clay.end_epoch(avg_epoch_loss)

        # Save the model once per epoch
        model_path = Path(cfg.model_dir) / f"minigpt_epoch_{epoch + 1}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'loss': avg_epoch_loss,
            'config': cfg.__dict__
        }, model_path)

        logger.info(f"💾 Model saved: {model_path.name}")

        # Also save as the best model if this is the best epoch loss so far
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            best_path = Path(cfg.model_dir) / "minigpt_best.pt"
            torch.save(model.state_dict(), best_path)
            logger.info(f"🏆 New best model! Loss: {avg_epoch_loss:.4f}")

    # ==================== END OF TRAINING ====================
    total_time = time.time() - clay.stats['start_time']

    logger.info("=" * 60)
    logger.info("🎉 TRAINING FINISHED!")
    logger.info("=" * 60)

    # Summary statistics
    logger.info(f"📊 SUMMARY:")
    logger.info(f"   • Total time: {clay._format_time(total_time)}")
    logger.info(f"   • Final loss: {clay.stats['loss_history'][-1]:.4f}")
    logger.info(f"   • Best loss: {clay.stats['best_loss']:.4f}")
    logger.info(f"   • Average loss: {sum(clay.stats['loss_history']) / len(clay.stats['loss_history']):.4f}")
    logger.info(f"   • Average time/step: {sum(clay.stats['step_times']) / len(clay.stats['step_times']):.3f}s")
    logger.info(f"   • Average time/epoch: {sum(clay.stats['epoch_times']) / len(clay.stats['epoch_times']):.1f}s")
    logger.info(f"   • Total steps: {clay.progress['current_step']:,}")

    # Save the final statistics
    stats_path = Path(cfg.checkpoints_dir) / "training_stats.json"
    final_stats = {
        'total_time': total_time,
        'best_loss': float(clay.stats['best_loss']),
        'final_loss': float(clay.stats['loss_history'][-1]),
        'avg_loss': float(sum(clay.stats['loss_history']) / len(clay.stats['loss_history'])),
        'total_steps': int(clay.progress['current_step']),
        'total_epochs': int(cfg.epochs),
        'learning_rate_history': [float(lr) for lr in clay.stats['learning_rates']],
        'loss_history': [float(loss) for loss in clay.stats['loss_history']],
        'completion_time': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }

    with open(stats_path, 'w', encoding='utf-8') as f:
        json.dump(final_stats, f, indent=2, ensure_ascii=False)

    logger.info(f"📈 Statistics saved: {stats_path}")

    # Save the final model
    final_path = Path(cfg.model_dir) / "minigpt_final.pt"
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': cfg.__dict__,
        'stats': final_stats
    }, final_path)

    logger.info(f"💾 Final model saved: {final_path}")

    # Generation smoke test
    logger.info(f"\n🧪 GENERATION TEST:")
    test_prompts = [
        "Witaj, ",
        "Python to ",
        "Dzisiaj jest ",
        "AI to "
    ]

    for prompt in test_prompts:
        generated = model.generate_text(prompt, max_len=50, temperature=0.8)
        logger.info(f"   • '{prompt}' -> '{generated[:50]}...'")

    logger.info("=" * 60)
    logger.info("✅ TRAINING COMPLETED SUCCESSFULLY!")
    logger.info("=" * 60)


def continue_training():
    """Continues training from the latest checkpoint"""
    logger.info("🔄 CONTINUING TRAINING FROM CHECKPOINT")
    train_model(resume=True)


def train_more_epochs(additional_epochs=3):
    """Adds more training epochs"""
    logger.info(f"📈 ADDING {additional_epochs} TRAINING EPOCHS")

    # Update the number of epochs in the configuration
    original_epochs = cfg.epochs
    cfg.epochs = original_epochs + additional_epochs

    logger.info(f"   • Current epochs: {original_epochs}")
    logger.info(f"   • New epochs: {cfg.epochs}")
    logger.info(f"   • Adding: {additional_epochs} epochs")

    # Continue training
    train_model(resume=True)

    # Restore the original value
    cfg.epochs = original_epochs


def load_model(model_path=None, device=None):
    """Loads a model from a file"""
    if device is None:
        device = sys_config.device

    if model_path is None:
        # Look for the most recent model
        models = list(Path(cfg.model_dir).glob("*.pt"))
        if not models:
            logger.error("❌ No model found to load!")
            return None

        # Prefer the best model, then the final one, then the newest
        best_path = Path(cfg.model_dir) / "minigpt_best.pt"
        final_path = Path(cfg.model_dir) / "minigpt_final.pt"

        if best_path.exists():
            model_path = best_path
            logger.info("📂 Loading the best model")
        elif final_path.exists():
            model_path = final_path
            logger.info("📂 Loading the final model")
        else:
            models.sort(key=lambda x: x.stat().st_mtime, reverse=True)
            model_path = models[0]
            logger.info(f"📂 Loading the newest model: {model_path.name}")

    logger.info(f"🔄 Loading model: {model_path}")

    try:
        checkpoint = torch.load(model_path, map_location='cpu')

        # Build the model
        model = MiniGPT60M()

        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # Legacy format: raw weights only
            model.load_state_dict(checkpoint)

        model.to(device)
        model.eval()

        logger.info(f"✅ Model loaded successfully")

        # Show some details
        if 'config' in checkpoint:
            logger.info(f"   • Config: {checkpoint['config'].get('n_layers', 'N/A')} layers")

        if 'loss' in checkpoint:
            logger.info(f"   • Loss: {checkpoint['loss']:.4f}")

        if 'epoch' in checkpoint:
            logger.info(f"   • Epoch: {checkpoint['epoch']}")

        return model

    except Exception as e:
        logger.error(f"❌ Failed to load the model: {e}")
        return None


def generate_text(prompt, model_path=None, device=None, max_length=200, temperature=0.8):
    """Generates text from a prompt"""
    logger.info(f"🎨 GENERATING TEXT: '{prompt}'")

    # Load the model
    model = load_model(model_path, device)
    if model is None:
        logger.error("❌ Could not load the model!")
        return

    # Generate
    start_time = time.time()
    generated = model.generate_text(prompt, max_len=max_length, temperature=temperature)
    gen_time = time.time() - start_time

    # Show the result
    logger.info("=" * 60)
    logger.info(f"📝 OUTPUT:")
    logger.info(generated)
    logger.info("=" * 60)
    logger.info(f"⏱️ Generation time: {gen_time:.2f}s")
    logger.info(f"📏 Length: {len(generated)} characters")

    # Save to a file
    output_dir = Path("results")
    output_dir.mkdir(exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = output_dir / f"generated_{timestamp}.txt"

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(f"Prompt: {prompt}\n")
        f.write(f"Generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Temperature: {temperature}\n")
        f.write(f"Max length: {max_length}\n")
        f.write("-" * 60 + "\n")
        f.write(generated + "\n")

    logger.info(f"💾 Result saved: {output_file}")

    return generated


def start_chat(model_path=None, device=None):
    """Starts an interactive chat session with the model"""
    logger.info("💬 STARTING CHAT MODE")
    logger.info("=" * 60)
    logger.info("Type 'exit' to quit")
    logger.info("Type 'reset' to clear the context")
    logger.info("=" * 60)

    # Load the model
    model = load_model(model_path, device)
    if model is None:
        return

    # Conversation history
    conversation_history = []
    max_history = 5

    while True:
        try:
            # Read user input
            user_input = input("\n👤 You: ").strip()

            if user_input.lower() == 'exit':
                logger.info("👋 Goodbye!")
                break

            if user_input.lower() == 'reset':
                conversation_history = []
                logger.info("🔄 History reset")
                continue

            if not user_input:
                continue

            # Build the prompt from the history
            if conversation_history:
                prompt = "\n".join(conversation_history[-max_history:]) + "\n👤 " + user_input + "\n🤖 "
            else:
                prompt = "👤 " + user_input + "\n🤖 "

            # Generate the reply
            logger.info("🤖 Thinking...")
            start_time = time.time()

            response = model.generate_text(
                prompt,
                max_len=200,
                temperature=0.8
            )

            # Keep only the model's part of the reply
            if "🤖 " in response:
                response = response.split("🤖 ")[-1].strip()

            gen_time = time.time() - start_time

            # Show the reply
            print(f"🤖 AI: {response}")
            print(f"   ⏱️ {gen_time:.2f}s")

            # Add to the history
            conversation_history.append(f"👤 {user_input}")
            conversation_history.append(f"🤖 {response}")

            # Trim the history
            if len(conversation_history) > max_history * 2:
                conversation_history = conversation_history[-(max_history * 2):]

        except KeyboardInterrupt:
            logger.info("\n👋 Interrupted by the user")
            break
        except Exception as e:
            logger.error(f"❌ Error: {e}")
            continue


def evaluate_model(model_path=None, device=None, test_samples=100):
    """Evaluates the model on randomly chosen test samples"""
    logger.info("📊 EVALUATING THE MODEL")

    if device is None:
        device = sys_config.device

    # Load the model
    model = load_model(model_path, device)
    if model is None:
        return

    # Load the data
    data_file = Path(cfg.prepared_dir) / "all_data.txt"
    if not data_file.exists():
        logger.error("❌ No data available for evaluation")
        return

    dataset = TextDataset(str(data_file), max_len=cfg.max_len)

    # Pick test samples
    test_indices = random.sample(range(len(dataset)), min(test_samples, len(dataset)))

    model.eval()
    total_loss = 0
    total_perplexity = 0

    logger.info(f"🔍 Testing on {len(test_indices)} samples...")

    with torch.no_grad():
        for i, idx in enumerate(test_indices):
            x, y = dataset[idx]
            x = x.unsqueeze(0).to(device)
            y = y.unsqueeze(0).to(device)

            _, loss = model(x, y)
            total_loss += loss.item()

            # Perplexity
            perplexity = torch.exp(loss).item()
            total_perplexity += perplexity

            if (i + 1) % 10 == 0:
                logger.info(f"   Processed {i + 1}/{len(test_indices)}...")

    avg_loss = total_loss / len(test_indices)
    avg_perplexity = total_perplexity / len(test_indices)

    logger.info("=" * 60)
    logger.info("📈 EVALUATION RESULTS:")
    logger.info(f"   • Average loss: {avg_loss:.4f}")
    logger.info(f"   • Perplexity: {avg_perplexity:.2f}")
    logger.info(f"   • Test samples: {len(test_indices)}")
    logger.info("=" * 60)

    # Save the results
    eval_dir = Path("evaluation")
    eval_dir.mkdir(exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    eval_file = eval_dir / f"eval_{timestamp}.json"

    results = {
        'timestamp': timestamp,
        'avg_loss': float(avg_loss),
        'avg_perplexity': float(avg_perplexity),
        'test_samples': len(test_indices),
        'model': str(model_path)
    }

    with open(eval_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    logger.info(f"💾 Results saved: {eval_file}")

    return avg_loss, avg_perplexity


# ==================== MAIN GUARD ====================
if __name__ == "__main__":
    # Quick self-test
    print("🤖 MiniGPT-60M AI Module")
    print("Usage: python main.py --train / --generate / --chat")

    # Tokenizer test
    test_text = "Hello world!"
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)

    print(f"\n🧪 Tokenizer test: '{test_text}' -> {encoded} -> '{decoded}'")

    # Model test
    model = MiniGPT60M()
    test_input = torch.randint(0, cfg.vocab_size, (2, 16))
    logits, _ = model(test_input)

    print(f"🧪 Model test: input {test_input.shape} -> output {logits.shape}")
    print(f"✅ AI module ready to use!")