aai/ai.py

1247 lines
42 KiB
Python
Raw Normal View History

2026-01-26 15:19:15 +00:00
"""
🎯 AI.PY - MiniGPT-60M z Clay Checkpoint System
"""
import os
import sys
import time
import math
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datetime import datetime, timedelta
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import random
import numpy as np
from config import logger, cfg, sys_config
# ==================== TOKENIZER ====================
class Tokenizer:
def __init__(self):
self.vocab = cfg.vocab
self.vocab_size = cfg.vocab_size
self.char_to_idx = {ch: i for i, ch in enumerate(self.vocab)}
self.idx_to_char = {i: ch for i, ch in enumerate(self.vocab)}
def encode(self, text: str) -> List[int]:
"""Enkoduje tekst na listę indeksów"""
return [self.char_to_idx.get(ch, 0) for ch in text if ch in self.char_to_idx]
def decode(self, indices: List[int]) -> str:
"""Dekoduje listę indeksów na tekst"""
return ''.join([self.idx_to_char.get(idx, '') for idx in indices])
def encode_batch(self, texts: List[str]) -> torch.Tensor:
"""Enkoduje batch tekstów"""
encoded = [self.encode(text) for text in texts]
max_len = max(len(e) for e in encoded)
padded = [e + [0] * (max_len - len(e)) for e in encoded]
return torch.tensor(padded, dtype=torch.long)
def decode_batch(self, tensors: torch.Tensor) -> List[str]:
"""Dekoduje batch tensorów"""
texts = []
for tensor in tensors:
indices = tensor.tolist()
# Usuń padding (0)
indices = [idx for idx in indices if idx != 0]
texts.append(self.decode(indices))
return texts
# Globalny tokenizer
tokenizer = Tokenizer()
# ==================== DATASET ====================
class TextDataset(Dataset):
def __init__(self, filepath: str, max_len: int = 512):
self.filepath = filepath
self.max_len = max_len
# Wczytaj dane
with open(filepath, 'r', encoding='utf-8') as f:
self.lines = [line.strip() for line in f if line.strip()]
logger.info(f"📊 Dataset: {len(self.lines):,} linii")
def __len__(self):
return len(self.lines)
def __getitem__(self, idx):
text = self.lines[idx]
# Przycinaj lub paduj do max_len
if len(text) > self.max_len:
start = random.randint(0, len(text) - self.max_len)
text = text[start:start + self.max_len]
# Enkoduj
encoded = tokenizer.encode(text)
# Paduj jeśli za krótkie
if len(encoded) < self.max_len:
encoded = encoded + [0] * (self.max_len - len(encoded))
x = torch.tensor(encoded[:-1], dtype=torch.long)
y = torch.tensor(encoded[1:], dtype=torch.long)
return x, y
# ==================== CLAY CHECKPOINT SYSTEM ====================
class ClayCheckpoint:
"""Zaawansowany system checkpointów z Clay UI"""
def __init__(self, model, optimizer, scheduler=None):
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.checkpoint_dir = Path(cfg.checkpoints_dir)
self.checkpoint_dir.mkdir(exist_ok=True)
# Training stats
self.stats = {
'start_time': time.time(),
'epoch_times': [],
'step_times': [],
'loss_history': [],
'learning_rates': [],
'best_loss': float('inf')
}
# Progress tracking
self.progress = {
'current_step': 0,
'total_steps': 0,
'current_epoch': 0,
'total_epochs': cfg.epochs,
'estimated_completion': None
}
# Progress bar style
self.progress_style = {
'filled': '',
'empty': '',
'arrow': '',
'spinner': ['', '', '', '', '', '', '', '', '', '']
}
def save(self, step, epoch, loss, is_best=False, force=False):
"""Zapisuje checkpoint"""
if not force and step % cfg.checkpoint_freq != 0:
return
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"clay_checkpoint_ep{epoch}_step{step}_{timestamp}.pt"
filepath = self.checkpoint_dir / filename
# Przygotuj stan
checkpoint = {
'step': step,
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'loss': loss,
'stats': self.stats,
'progress': self.progress,
'timestamp': timestamp,
'config': {
'vocab_size': cfg.vocab_size,
'embed_dim': cfg.embed_dim,
'n_layers': cfg.n_layers,
'n_heads': cfg.n_heads
}
}
if self.scheduler:
checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
# Zapisz
torch.save(checkpoint, filepath)
# Zapisz jako JSON dla czytelności
json_path = filepath.with_suffix('.json')
json_data = {
'checkpoint_info': {
'filename': filename,
'step': step,
'epoch': epoch,
'loss': float(loss),
'timestamp': timestamp,
'file_size': os.path.getsize(filepath)
},
'training_stats': {
'total_time': time.time() - self.stats['start_time'],
'avg_loss': float(sum(self.stats['loss_history'][-100:]) / min(100, len(self.stats['loss_history']))),
'current_lr': float(self.optimizer.param_groups[0]['lr']),
'steps_done': self.progress['current_step']
}
}
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(json_data, f, indent=2, ensure_ascii=False)
# Zachowaj tylko N ostatnich checkpointów
self.cleanup_old_checkpoints()
logger.info(f"💾 Clay Checkpoint: {filename} (loss: {loss:.4f})")
if is_best:
best_path = self.checkpoint_dir / "clay_best.pt"
torch.save(checkpoint, best_path)
logger.info(f"🏆 Nowy najlepszy model zapisany!")
def load_latest(self):
"""Wczytuje najnowszy checkpoint"""
checkpoints = list(self.checkpoint_dir.glob("clay_checkpoint_*.pt"))
if not checkpoints:
return None, None, None, None
# Znajdź najnowszy po czasie
latest = max(checkpoints, key=os.path.getmtime)
logger.info(f"🔄 Ładowanie Clay Checkpoint: {latest.name}")
checkpoint = torch.load(latest, map_location='cpu')
# Przywróć stan
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint and self.scheduler:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# Przywróć statystyki
if 'stats' in checkpoint:
self.stats.update(checkpoint['stats'])
if 'progress' in checkpoint:
self.progress.update(checkpoint['progress'])
logger.info(f"✅ Wczytano checkpoint: epoka {checkpoint['epoch']}, krok {checkpoint['step']}")
return checkpoint['epoch'], checkpoint['step'], checkpoint['loss'], checkpoint.get('timestamp')
def cleanup_old_checkpoints(self):
"""Czyści stare checkpointy"""
checkpoints = list(self.checkpoint_dir.glob("clay_checkpoint_*.pt"))
checkpoints.sort(key=os.path.getmtime)
keep = cfg.get('keep_checkpoints', 5)
while len(checkpoints) > keep:
old = checkpoints.pop(0)
# Zostaw najlepszy
if "best" not in old.name:
old.unlink()
# Usuń też JSON
json_file = old.with_suffix('.json')
if json_file.exists():
json_file.unlink()
def estimate_time_remaining(self):
"""Szacuje pozostały czas treningu"""
if not self.stats['step_times']:
return "Obliczanie..."
avg_time_per_step = sum(self.stats['step_times'][-100:]) / min(100, len(self.stats['step_times']))
steps_done = self.progress['current_step']
total_steps = self.progress['total_steps']
if steps_done == 0 or total_steps == 0:
return "Obliczanie..."
steps_left = total_steps - steps_done
seconds_left = steps_left * avg_time_per_step
# Jeśli mamy dane z epok, użyj ich
if self.stats['epoch_times']:
epochs_done = self.progress['current_epoch']
epochs_left = self.progress['total_epochs'] - epochs_done
avg_epoch_time = sum(self.stats['epoch_times']) / len(self.stats['epoch_times'])
epoch_based = epochs_left * avg_epoch_time
seconds_left = min(seconds_left, epoch_based)
return self._format_time(seconds_left)
def _format_time(self, seconds):
"""Formatuje czas"""
if seconds < 60:
return f"{seconds:.0f}s"
elif seconds < 3600:
return f"{seconds/60:.1f}m"
elif seconds < 86400:
return f"{seconds/3600:.1f}h"
else:
return f"{seconds/86400:.1f}d"
def get_progress_bar(self, width=40):
"""Generuje pasek postępu"""
if self.progress['total_steps'] == 0:
return "[░░░░░░░░░░░░░░░░░░░░]"
progress = min(1.0, self.progress['current_step'] / self.progress['total_steps'])
filled = int(width * progress)
empty = width - filled
bar = self.progress_style['filled'] * filled
if filled < width:
bar += self.progress_style['arrow']
bar += self.progress_style['empty'] * empty
return f"[{bar}]"
def print_progress(self, step_loss, current_lr):
"""Wyświetla aktualny postęp"""
progress_bar = self.get_progress_bar(40)
# Oblicz procent
percent = (self.progress['current_step'] / self.progress['total_steps']) * 100
# Szacowany czas
eta = self.estimate_time_remaining()
# Spinner (animowany)
spinner_idx = int(time.time() * 4) % len(self.progress_style['spinner'])
spinner = self.progress_style['spinner'][spinner_idx]
# Formatuj wyjście
output = (f"\r{spinner} {progress_bar} {percent:5.1f}% | "
f"Step: {self.progress['current_step']:,}/{self.progress['total_steps']:,} | "
f"Loss: {step_loss:.4f} | LR: {current_lr:.6f} | "
f"ETA: {eta}")
print(output, end='', flush=True)
# Na końcu epoki, przejdź do nowej linii
if self.progress['current_step'] == self.progress['total_steps']:
print()
def start_epoch(self, epoch, total_batches):
"""Rozpoczyna nową epokę"""
self.progress['current_epoch'] = epoch
self.progress['total_steps'] = total_batches * self.progress['total_epochs']
self.epoch_start_time = time.time()
logger.info(f"\n🚀 ROZPOCZYNAM EPOKĘ {epoch+1}/{self.progress['total_epochs']}")
logger.info(f" • Kroki w epoce: {total_batches:,}")
logger.info(f" • Całkowite kroki: {self.progress['total_steps']:,}")
# Szacowanie czasu dla tej epoki
if self.stats['epoch_times']:
avg_epoch_time = sum(self.stats['epoch_times']) / len(self.stats['epoch_times'])
logger.info(f" • Szacowany czas epoki: {self._format_time(avg_epoch_time)}")
logger.info(f" • Szacowany czas do końca: {self.estimate_time_remaining()}")
def end_epoch(self, epoch_loss):
"""Kończy epokę"""
epoch_time = time.time() - self.epoch_start_time
self.stats['epoch_times'].append(epoch_time)
# Aktualizuj najlepszy loss
if epoch_loss < self.stats['best_loss']:
self.stats['best_loss'] = epoch_loss
# Oblicz pozostały czas
avg_epoch_time = sum(self.stats['epoch_times']) / len(self.stats['epoch_times'])
epochs_left = self.progress['total_epochs'] - self.progress['current_epoch'] - 1
total_time_left = avg_epoch_time * epochs_left
# Podsumowanie epoki
logger.info(f"📊 EPOKA {self.progress['current_epoch']+1} ZAKOŃCZONA:")
logger.info(f" • Loss: {epoch_loss:.4f}")
logger.info(f" • Najlepszy loss: {self.stats['best_loss']:.4f}")
logger.info(f" • Czas epoki: {self._format_time(epoch_time)}")
logger.info(f" • Średni czas/epokę: {self._format_time(avg_epoch_time)}")
logger.info(f" • Pozostało: ~{self._format_time(total_time_left)}")
# Przewidywane zakończenie
if epochs_left > 0:
eta_time = datetime.now() + timedelta(seconds=total_time_left)
logger.info(f" • Przewidywane zakończenie: {eta_time.strftime('%Y-%m-%d %H:%M:%S')}")
# ==================== MODEL ARCHITECTURE ====================
class AttentionHead(nn.Module):
def __init__(self, embed_dim, head_dim, dropout=0.1):
super().__init__()
self.head_dim = head_dim
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(embed_dim, 3 * head_dim)
self.proj = nn.Linear(head_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
B, T, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.scale
if mask is not None:
attn = attn.masked_fill(mask == 0, float('-inf'))
attn = attn.softmax(dim=-1)
attn = self.dropout(attn)
out = attn @ v
out = self.proj(out)
return out
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
assert embed_dim % num_heads == 0
self.head_dim = embed_dim // num_heads
self.num_heads = num_heads
self.heads = nn.ModuleList([
AttentionHead(embed_dim, self.head_dim, dropout)
for _ in range(num_heads)
])
self.proj = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Równoległe przetwarzanie przez wszystkie głowy
head_outputs = [head(x, mask) for head in self.heads]
# Konkatenacja
out = torch.cat(head_outputs, dim=-1)
out = self.proj(out)
out = self.dropout(out)
return out
class FeedForward(nn.Module):
def __init__(self, embed_dim, ff_dim, dropout=0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(embed_dim, ff_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(ff_dim, embed_dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
super().__init__()
self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
self.ff = FeedForward(embed_dim, ff_dim, dropout)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
def forward(self, x, mask=None):
# Self-attention z residual connection
attn_out = self.attention(self.norm1(x), mask)
x = x + attn_out
# Feed forward z residual connection
ff_out = self.ff(self.norm2(x))
x = x + ff_out
return x
class MiniGPT60M(nn.Module):
def __init__(self):
super().__init__()
# Embeddings
self.token_embedding = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
self.position_embedding = nn.Embedding(cfg.max_len, cfg.embed_dim)
# Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(cfg.embed_dim, cfg.n_heads, cfg.ff_dim, cfg.dropout)
for _ in range(cfg.n_layers)
])
# Final layers
self.norm = nn.LayerNorm(cfg.embed_dim)
self.head = nn.Linear(cfg.embed_dim, cfg.vocab_size)
# Initialize weights
self.apply(self._init_weights)
# Count parameters
total_params = sum(p.numel() for p in self.parameters())
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
logger.info(f"🤖 Model MiniGPT-60M stworzony")
logger.info(f" • Parametry: {total_params:,} (trainable: {trainable_params:,})")
logger.info(f" • Embed dim: {cfg.embed_dim}")
logger.info(f" • Warstwy: {cfg.n_layers}")
logger.info(f" • Głowy: {cfg.n_heads}")
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx, targets=None):
B, T = idx.shape
# Embeddings
tok_emb = self.token_embedding(idx)
pos = torch.arange(T, device=idx.device).unsqueeze(0)
pos_emb = self.position_embedding(pos)
x = tok_emb + pos_emb
# Causal mask
mask = torch.tril(torch.ones(T, T, device=idx.device)).view(1, 1, T, T)
# Transformer blocks
for block in self.blocks:
x = block(x, mask)
# Final layer
x = self.norm(x)
logits = self.head(x)
# Loss if targets provided
loss = None
if targets is not None:
loss = nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1)
)
return logits, loss
def generate(self, input_ids=None, max_length=100, temperature=1.0, **kwargs):
"""Uniwersalna metoda generate obsługująca różne interfejsy"""
# Ignoruj nieobsługiwane argumenty (device, max_len, etc.)
max_len = kwargs.get('max_len', max_length)
if isinstance(input_ids, str):
return self.generate_text(input_ids, max_len=max_len, temperature=temperature)
elif input_ids is not None:
# Jeśli to tensor, przekonwertuj na tekst
if hasattr(input_ids, 'tolist'):
if len(input_ids.shape) > 1:
text = tokenizer.decode(input_ids[0].tolist())
else:
text = tokenizer.decode(input_ids.tolist())
else:
text = tokenizer.decode(input_ids)
return self.generate_text(text, max_len=max_len, temperature=temperature)
else:
return "" # Pusty string dla None input
def generate_text(self, prompt, max_len=100, temperature=1.0):
"""Generuje tekst na podstawie promptu"""
self.eval()
# Encode prompt
input_ids = tokenizer.encode(prompt)
if len(input_ids) == 0:
return prompt
# Przygotuj tensor
input_tensor = torch.tensor([input_ids], dtype=torch.long)
if torch.cuda.is_available():
input_tensor = input_tensor.cuda()
generated = input_ids.copy()
with torch.no_grad():
for _ in range(max_len):
# Forward pass
logits, _ = self(input_tensor)
# Weź logity dla ostatniego tokena
next_token_logits = logits[0, -1, :] / temperature
# Top-k sampling
if cfg.top_k > 0:
indices_to_remove = next_token_logits < torch.topk(next_token_logits, cfg.top_k)[0][..., -1, None]
next_token_logits[indices_to_remove] = float('-inf')
# Top-p sampling
if cfg.top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > cfg.top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
next_token_logits[indices_to_remove] = float('-inf')
# Sample next token
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).item()
# Dodaj do sekwencji
generated.append(next_token)
# Aktualizuj input
input_tensor = torch.tensor([generated[-cfg.max_len:]], dtype=torch.long)
if torch.cuda.is_available():
input_tensor = input_tensor.cuda()
# Sprawdź token kończący
if next_token == tokenizer.char_to_idx.get('\n', -1):
break
# Decode
generated_text = tokenizer.decode(generated)
return generated_text
# ==================== TRAINING FUNCTIONS ====================
def create_optimizer(model, learning_rate=3e-4, weight_decay=0.1):
"""Tworzy optimizer z dekayem wag"""
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if 'weight' in name and len(param.shape) > 1:
decay_params.append(param)
else:
no_decay_params.append(param)
optimizer = optim.AdamW([
{'params': decay_params, 'weight_decay': weight_decay},
{'params': no_decay_params, 'weight_decay': 0.0}
], lr=learning_rate, betas=(cfg.adam_beta1, cfg.adam_beta2), eps=cfg.adam_eps)
return optimizer
def create_scheduler(optimizer, warmup_steps=2000):
"""Tworzy scheduler z warmupem"""
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=cfg.learning_rate,
total_steps=cfg.epochs * 1000, # Szacunkowo
pct_start=warmup_steps / (cfg.epochs * 1000),
anneal_strategy='cos'
)
return scheduler
def benchmark_device():
"""Testuje wydajność urządzenia"""
logger.info("🏃‍♂️ BENCHMARK URZĄDZENIA:")
if torch.cuda.is_available():
device_name = torch.cuda.get_device_name(0)
memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
logger.info(f" • GPU: {device_name}")
logger.info(f" • Pamięć GPU: {memory_gb:.1f} GB")
# Test wydajności GPU
start = time.time()
x = torch.randn(1024, 768, device='cuda')
y = torch.randn(768, 512, device='cuda')
for _ in range(100):
_ = torch.matmul(x, y)
torch.cuda.synchronize()
gpu_time = time.time() - start
logger.info(f" • Wydajność GPU: {gpu_time:.2f}s na 100 operacji")
# Test CPU
x = torch.randn(1024, 768)
y = torch.randn(768, 512)
start = time.time()
for _ in range(100):
_ = torch.matmul(x, y)
cpu_time = time.time() - start
logger.info(f" • Wydajność CPU: {cpu_time:.2f}s na 100 operacji")
if torch.cuda.is_available():
speedup = cpu_time / gpu_time
logger.info(f" • Przyśpieszenie GPU vs CPU: {speedup:.1f}x")
def estimate_training_time(total_samples, batch_size, epochs, device):
"""Szacuje czas treningu przed rozpoczęciem"""
steps_per_epoch = math.ceil(total_samples / batch_size)
total_steps = steps_per_epoch * epochs
logger.info("⏱️ SZACOWANIE CZASU TRENINGU:")
logger.info(f" • Próbki: {total_samples:,}")
logger.info(f" • Epoki: {epochs}")
logger.info(f" • Batch size: {batch_size}")
logger.info(f" • Kroki: {total_steps:,}")
# Benchmarki dla różnych urządzeń (sekundy na 1000 kroków)
benchmarks = {
'cuda': {'T4': 120, 'P100': 90, 'V100': 60, 'A100': 30, 'default': 100},
'mps': {'M1': 150, 'M2': 100, 'default': 120},
'cpu': {'default': 600}
}
steps_in_k = total_steps / 1000
if device == 'cuda':
try:
gpu_name = torch.cuda.get_device_name(0)
if 'T4' in gpu_name:
time_per_k = benchmarks['cuda']['T4']
elif 'P100' in gpu_name:
time_per_k = benchmarks['cuda']['P100']
elif 'V100' in gpu_name:
time_per_k = benchmarks['cuda']['V100']
elif 'A100' in gpu_name:
time_per_k = benchmarks['cuda']['A100']
else:
time_per_k = benchmarks['cuda']['default']
except:
time_per_k = benchmarks['cuda']['default']
elif device == 'mps':
time_per_k = benchmarks['mps']['default']
else:
time_per_k = benchmarks['cpu']['default']
estimated_total = steps_in_k * time_per_k
# Formatuj czas
if estimated_total > 3600 * 24:
days = estimated_total / (3600 * 24)
time_str = f"~{days:.1f} dni"
elif estimated_total > 3600:
hours = estimated_total / 3600
time_str = f"~{hours:.1f} godzin"
elif estimated_total > 60:
minutes = estimated_total / 60
time_str = f"~{minutes:.1f} minut"
else:
time_str = f"~{estimated_total:.0f} sekund"
logger.info(f" • Szacowany czas: {time_str}")
# Przewidywane zakończenie
eta = datetime.now() + timedelta(seconds=estimated_total)
logger.info(f" • Przewidywane zakończenie: {eta.strftime('%Y-%m-%d %H:%M:%S')}")
return total_steps, estimated_total
def train_model(resume=False):
"""Główna funkcja treningowa"""
logger.info("=" * 60)
logger.info("🚀 ROZPOCZĘCIE TRENINGU MINIGPT-60M")
logger.info("=" * 60)
# Setup urządzenia
device = sys_config.device
if device == 'cuda' and not torch.cuda.is_available():
logger.warning("⚠ CUDA niedostępne, używam CPU")
device = 'cpu'
logger.info(f"🖥️ Urządzenie: {device.upper()}")
# Benchmark
benchmark_device()
# Wczytaj dane
data_file = Path(cfg.prepared_dir) / "all_data.txt"
if not data_file.exists():
logger.error(f"❌ Brak danych: {data_file}")
logger.info("💡 Uruchom: python main.py --prepare")
return
dataset = TextDataset(str(data_file), max_len=cfg.max_len)
train_loader = DataLoader(
dataset,
batch_size=cfg.batch_size,
shuffle=True,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory
)
# Stwórz model
model = MiniGPT60M()
model.to(device)
# Loss function
criterion = nn.CrossEntropyLoss()
# Optimizer & scheduler
optimizer = create_optimizer(model, cfg.learning_rate, cfg.weight_decay)
scheduler = create_scheduler(optimizer, cfg.warmup_steps)
# Clay Checkpoint System
clay = ClayCheckpoint(model, optimizer, scheduler)
# Resume jeśli potrzebne
start_epoch = 0
start_step = 0
if resume:
loaded_epoch, loaded_step, loaded_loss, timestamp = clay.load_latest()
if loaded_epoch is not None:
start_epoch = loaded_epoch
start_step = loaded_step
logger.info(f"🔄 Wznawianie z epoki {start_epoch}, kroku {start_step}")
logger.info(f" • Ostatni loss: {loaded_loss:.4f}")
logger.info(f" • Timestamp: {timestamp}")
else:
logger.warning("⚠ Nie znaleziono checkpointu, zaczynam od początku")
# Szacowanie czasu
total_steps, estimated_total = estimate_training_time(
len(dataset), cfg.batch_size, cfg.epochs - start_epoch, device
)
logger.info("=" * 60)
logger.info("🎬 ROZPOCZĘTO TRENING!")
logger.info("=" * 60)
# Główna pętla treningowa
best_loss = float('inf')
for epoch in range(start_epoch, cfg.epochs):
clay.start_epoch(epoch, len(train_loader))
epoch_loss = 0
model.train()
for batch_idx, (x, y) in enumerate(train_loader):
step_start_time = time.time()
x, y = x.to(device), y.to(device)
# Forward pass
logits, loss = model(x, y)
# Backward pass
optimizer.zero_grad()
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad)
optimizer.step()
if scheduler:
scheduler.step()
# Aktualizuj statystyki
step_time = time.time() - step_start_time
clay.stats['step_times'].append(step_time)
clay.stats['loss_history'].append(loss.item())
clay.stats['learning_rates'].append(optimizer.param_groups[0]['lr'])
# Aktualizuj progress
clay.progress['current_step'] = epoch * len(train_loader) + batch_idx
epoch_loss += loss.item()
# Wyświetl progress co 10 kroków lub 1% batchy
if batch_idx % max(1, len(train_loader) // 100) == 0 or batch_idx < 10:
clay.print_progress(loss.item(), optimizer.param_groups[0]['lr'])
# Zapisz checkpoint co 1000 kroków
if clay.progress['current_step'] % 1000 == 0 and clay.progress['current_step'] > 0:
is_best = loss.item() < best_loss
if is_best:
best_loss = loss.item()
clay.save(
clay.progress['current_step'],
epoch,
loss.item(),
is_best=is_best
)
# Koniec epoki
avg_epoch_loss = epoch_loss / len(train_loader)
clay.end_epoch(avg_epoch_loss)
# Zapisz model co epokę
model_path = Path(cfg.model_dir) / f"minigpt_epoch_{epoch + 1}.pt"
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
'loss': avg_epoch_loss,
'config': cfg.__dict__
}, model_path)
logger.info(f"💾 Model zapisany: {model_path.name}")
# Zapisz też jako najlepszy jeśli to najlepszy loss
if avg_epoch_loss < best_loss:
best_loss = avg_epoch_loss
best_path = Path(cfg.model_dir) / "minigpt_best.pt"
torch.save(model.state_dict(), best_path)
logger.info(f"🏆 Nowy najlepszy model! Loss: {avg_epoch_loss:.4f}")
# ==================== KONIEC TRENINGU ====================
total_time = time.time() - clay.stats['start_time']
logger.info("=" * 60)
logger.info("🎉 TRENING ZAKOŃCZONY!")
logger.info("=" * 60)
# Podsumowanie statystyk
logger.info(f"📊 PODSUMOWANIE:")
logger.info(f" • Całkowity czas: {clay._format_time(total_time)}")
logger.info(f" • Ostatni loss: {clay.stats['loss_history'][-1]:.4f}")
logger.info(f" • Najlepszy loss: {clay.stats['best_loss']:.4f}")
logger.info(f" • Średni loss: {sum(clay.stats['loss_history']) / len(clay.stats['loss_history']):.4f}")
logger.info(f" • Średni czas/krok: {sum(clay.stats['step_times']) / len(clay.stats['step_times']):.3f}s")
logger.info(f" • Średni czas/epokę: {sum(clay.stats['epoch_times']) / len(clay.stats['epoch_times']):.1f}s")
logger.info(f" • Całkowite kroki: {clay.progress['current_step']:,}")
# Zapisz finalne statystyki
stats_path = Path(cfg.checkpoints_dir) / "training_stats.json"
final_stats = {
'total_time': total_time,
'best_loss': float(clay.stats['best_loss']),
'final_loss': float(clay.stats['loss_history'][-1]),
'avg_loss': float(sum(clay.stats['loss_history']) / len(clay.stats['loss_history'])),
'total_steps': int(clay.progress['current_step']),
'total_epochs': int(cfg.epochs),
'learning_rate_history': [float(lr) for lr in clay.stats['learning_rates']],
'loss_history': [float(loss) for loss in clay.stats['loss_history']],
'completion_time': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
with open(stats_path, 'w', encoding='utf-8') as f:
json.dump(final_stats, f, indent=2, ensure_ascii=False)
logger.info(f"📈 Statystyki zapisane: {stats_path}")
# Zapisz finalny model
final_path = Path(cfg.model_dir) / "minigpt_final.pt"
torch.save({
'model_state_dict': model.state_dict(),
'config': cfg.__dict__,
'stats': final_stats
}, final_path)
logger.info(f"💾 Finalny model zapisany: {final_path}")
# Test generowania
logger.info(f"\n🧪 TEST GENEROWANIA:")
test_prompts = [
"Witaj, ",
"Python to ",
"Dzisiaj jest ",
"AI to "
]
for prompt in test_prompts:
generated = model.generate_text(prompt, max_len=50, temperature=0.8)
logger.info(f"'{prompt}' -> '{generated[:50]}...'")
logger.info("=" * 60)
logger.info("✅ TRENING ZAKOŃCZONY POMYŚLNIE!")
logger.info("=" * 60)
def continue_training():
"""Kontynuuje trening od ostatniego checkpointu"""
logger.info("🔄 KONTYNUUJĘ TRENING OD CHECKPOINTU")
train_model(resume=True)
def train_more_epochs(additional_epochs=3):
"""Dodaje więcej epok treningu"""
logger.info(f"📈 DODAJĘ {additional_epochs} EPOK TRENINGU")
# Zaktualizuj liczbę epok w konfiguracji
original_epochs = cfg.epochs
cfg.epochs = original_epochs + additional_epochs
logger.info(f" • Obecne epoki: {original_epochs}")
logger.info(f" • Nowe epoki: {cfg.epochs}")
logger.info(f" • Dodaję: {additional_epochs} epok")
# Kontynuuj trening
train_model(resume=True)
# Przywróć oryginalną wartość
cfg.epochs = original_epochs
def load_model(model_path=None, device=None):
"""Wczytuje model z pliku"""
if device is None:
device = sys_config.device
if model_path is None:
# Szukaj najnowszego modelu
models = list(Path(cfg.model_dir).glob("*.pt"))
if not models:
logger.error("❌ Brak modelu do wczytania!")
return None
# Preferuj najlepszy, potem finalny, potem najnowszy
best_path = Path(cfg.model_dir) / "minigpt_best.pt"
final_path = Path(cfg.model_dir) / "minigpt_final.pt"
if best_path.exists():
model_path = best_path
logger.info("📂 Wczytuję najlepszy model")
elif final_path.exists():
model_path = final_path
logger.info("📂 Wczytuję finalny model")
else:
models.sort(key=lambda x: x.stat().st_mtime, reverse=True)
model_path = models[0]
logger.info(f"📂 Wczytuję najnowszy model: {model_path.name}")
logger.info(f"🔄 Wczytywanie modelu: {model_path}")
try:
checkpoint = torch.load(model_path, map_location='cpu')
# Stwórz model
model = MiniGPT60M()
if 'model_state_dict' in checkpoint:
model.load_state_dict(checkpoint['model_state_dict'])
else:
# Stary format - same wagi
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
logger.info(f"✅ Model wczytany pomyślnie")
# Pokaż informacje
if 'config' in checkpoint:
logger.info(f" • Config: {checkpoint['config'].get('n_layers', 'N/A')} warstw")
if 'loss' in checkpoint:
logger.info(f" • Loss: {checkpoint['loss']:.4f}")
if 'epoch' in checkpoint:
logger.info(f" • Epoka: {checkpoint['epoch']}")
return model
except Exception as e:
logger.error(f"❌ Błąd wczytywania modelu: {e}")
return None
def generate_text(prompt, model_path=None, device=None, max_length=200, temperature=0.8):
"""Generuje tekst na podstawie promptu"""
logger.info(f"🎨 GENERUJĘ TEKST: '{prompt}'")
# Wczytaj model
model = load_model(model_path, device)
if model is None:
logger.error("❌ Nie można wczytać modelu!")
return
# Generuj
start_time = time.time()
generated = model.generate_text(prompt, max_len=max_length, temperature=temperature)
gen_time = time.time() - start_time
# Pokaż wynik
logger.info("=" * 60)
logger.info(f"📝 WYJŚCIE:")
logger.info(generated)
logger.info("=" * 60)
logger.info(f"⏱️ Czas generowania: {gen_time:.2f}s")
logger.info(f"📏 Długość: {len(generated)} znaków")
# Zapisz do pliku
output_dir = Path("results")
output_dir.mkdir(exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = output_dir / f"generated_{timestamp}.txt"
with open(output_file, 'w', encoding='utf-8') as f:
f.write(f"Prompt: {prompt}\n")
f.write(f"Generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"Temperature: {temperature}\n")
f.write(f"Max length: {max_length}\n")
f.write("-" * 60 + "\n")
f.write(generated + "\n")
logger.info(f"💾 Wynik zapisany: {output_file}")
return generated
def start_chat(model_path=None, device=None):
"""Uruchamia tryb rozmowy z modelem"""
logger.info("💬 ROZPOCZYNAM TRYB ROZMOWY")
logger.info("=" * 60)
logger.info("Wpisz 'exit' aby wyjść")
logger.info("Wpisz 'reset' aby zresetować kontekst")
logger.info("=" * 60)
# Wczytaj model
model = load_model(model_path, device)
if model is None:
return
# Historia rozmowy
conversation_history = []
max_history = 5
while True:
try:
# Pobierz input
user_input = input("\n👤 Ty: ").strip()
if user_input.lower() == 'exit':
logger.info("👋 Do widzenia!")
break
if user_input.lower() == 'reset':
conversation_history = []
logger.info("🔄 Historia zresetowana")
continue
if not user_input:
continue
# Przygotuj prompt z historią
if conversation_history:
prompt = "\n".join(conversation_history[-max_history:]) + "\n👤 " + user_input + "\n🤖 "
else:
prompt = "👤 " + user_input + "\n🤖 "
# Generuj odpowiedź
logger.info("🤖 Myślę...")
start_time = time.time()
response = model.generate_text(
prompt,
max_len=200,
temperature=0.8
)
# Wyodrębnij tylko odpowiedź modelu
if "🤖 " in response:
response = response.split("🤖 ")[-1].strip()
gen_time = time.time() - start_time
# Wyświetl odpowiedź
print(f"🤖 AI: {response}")
print(f" ⏱️ {gen_time:.2f}s")
# Dodaj do historii
conversation_history.append(f"👤 {user_input}")
conversation_history.append(f"🤖 {response}")
# Ogranicz historię
if len(conversation_history) > max_history * 2:
conversation_history = conversation_history[-(max_history * 2):]
except KeyboardInterrupt:
logger.info("\n👋 Przerwano przez użytkownika")
break
except Exception as e:
logger.error(f"❌ Błąd: {e}")
continue
def evaluate_model(model_path=None, device=None, test_samples=100):
"""Ocenia model na danych testowych"""
logger.info("📊 OCENIAM MODEL")
# Wczytaj model
model = load_model(model_path, device)
if model is None:
return
# Wczytaj dane
data_file = Path(cfg.prepared_dir) / "all_data.txt"
if not data_file.exists():
logger.error("❌ Brak danych do ewaluacji")
return
dataset = TextDataset(str(data_file), max_len=cfg.max_len)
# Wybierz próbki testowe
test_indices = random.sample(range(len(dataset)), min(test_samples, len(dataset)))
model.eval()
total_loss = 0
total_perplexity = 0
logger.info(f"🔍 Testowanie na {len(test_indices)} próbkach...")
with torch.no_grad():
for i, idx in enumerate(test_indices):
x, y = dataset[idx]
x = x.unsqueeze(0).to(device)
y = y.unsqueeze(0).to(device)
_, loss = model(x, y)
total_loss += loss.item()
# Perplexity
perplexity = torch.exp(loss).item()
total_perplexity += perplexity
if (i + 1) % 10 == 0:
logger.info(f" Przetworzono {i + 1}/{len(test_indices)}...")
avg_loss = total_loss / len(test_indices)
avg_perplexity = total_perplexity / len(test_indices)
logger.info("=" * 60)
logger.info("📈 WYNIKI EWALUACJI:")
logger.info(f" • Średni Loss: {avg_loss:.4f}")
logger.info(f" • Perplexity: {avg_perplexity:.2f}")
logger.info(f" • Próbki testowe: {len(test_indices)}")
logger.info("=" * 60)
# Zapisz wyniki
eval_dir = Path("evaluation")
eval_dir.mkdir(exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
eval_file = eval_dir / f"eval_{timestamp}.json"
results = {
'timestamp': timestamp,
'avg_loss': float(avg_loss),
'avg_perplexity': float(avg_perplexity),
'test_samples': len(test_indices),
'model': str(model_path)
}
with open(eval_file, 'w', encoding='utf-8') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
logger.info(f"💾 Wyniki zapisane: {eval_file}")
return avg_loss, avg_perplexity
# ==================== MAIN GUARD ====================
if __name__ == "__main__":
# Testowa funkcja
print("🤖 MiniGPT-60M AI Module")
print("Użyj: python main.py --train / --generate / --chat")
# Test tokenizera
test_text = "Hello world!"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
print(f"\n🧪 Test tokenizera: '{test_text}' -> {encoded} -> '{decoded}'")
# Test modelu
model = MiniGPT60M()
test_input = torch.randint(0, cfg.vocab_size, (2, 16))
logits, _ = model(test_input)
print(f"🧪 Test modelu: input {test_input.shape} -> output {logits.shape}")
print(f"✅ Moduł AI gotowy do użycia!")