#!/usr/bin/env python3 """Ostateczna poprawka metody generate()""" import re with open('ai.py', 'r') as f: content = f.read() # Nowa, uniwersalna metoda generate new_generate_method = ''' def generate(self, input_ids=None, max_length=100, temperature=1.0, **kwargs): """Uniwersalna metoda generate obsługująca różne interfejsy""" # Ignoruj nieobsługiwane argumenty (device, max_len, etc.) max_len = kwargs.get('max_len', max_length) if isinstance(input_ids, str): return self.generate_text(input_ids, max_len=max_len, temperature=temperature) elif input_ids is not None: # Jeśli to tensor, przekonwertuj na tekst if hasattr(input_ids, 'tolist'): if len(input_ids.shape) > 1: text = self.tokenizer.decode(input_ids[0].tolist()) else: text = self.tokenizer.decode(input_ids.tolist()) else: text = self.tokenizer.decode(input_ids) return self.generate_text(text, max_len=max_len, temperature=temperature) else: return "" # Pusty string dla None input ''' # Znajdź i zamień metodę generate if 'def generate(' in content: # Użyj regex do znalezienia całej metody pattern = r'def generate\(.*?\).*?(?=\n def \w+\(|\nclass \w+|\Z)' # Sprawdź czy regex działa match = re.search(pattern, content, re.DOTALL) if match: content = re.sub(pattern, new_generate_method, content, flags=re.DOTALL) print("✅ Zaktualizowano metodę generate()") else: # Alternatywny sposób: znajdź między def generate a następną def lines = content.split('\n') new_lines = [] i = 0 while i < len(lines): line = lines[i] new_lines.append(line) if line.strip().startswith('def generate('): # Pomijaj stare linie metody aż do następnej metody i += 1 while i < len(lines) and not lines[i].strip().startswith('def ') and not lines[i].strip().startswith('class '): i += 1 # Dodaj nową metodę new_lines.append(new_generate_method) if i < len(lines): new_lines.append(lines[i]) i += 1 continue i += 1 content = '\n'.join(new_lines) print("✅ Zastąpiono metodę generate() (alternatywna metoda)") else: print("❌ Metoda generate() nie znaleziona, dodaję...") # Dodaj przed ostatnią metodą w klasie if 'class MiniGPT60M' in content: # Wstaw przed ostatnim 'def' w klasie lines = content.split('\n') new_lines = [] in_class = False methods_found = [] for i, line in enumerate(lines): new_lines.append(line) if 'class MiniGPT60M' in line: in_class = True if in_class and line.strip().startswith('def '): methods_found.append(i) if methods_found: # Dodaj przed ostatnią metodą last_method_idx = methods_found[-1] new_lines.insert(last_method_idx, '\n' + new_generate_method) content = '\n'.join(new_lines) print("✅ Dodano metodę generate()") else: # Dodaj na końcu klasy if 'class MiniGPT60M' in content: # Znajdź koniec klasy class_pattern = r'(class MiniGPT60M.*?)(?=\nclass|\Z)' match = re.search(class_pattern, content, re.DOTALL) if match: class_content = match.group(1) updated_class = class_content.rstrip() + '\n\n' + new_generate_method content = content.replace(class_content, updated_class) print("✅ Dodano metodę generate() na końcu klasy") # Zapisz zmiany with open('ai.py', 'w') as f: f.write(content) print("✅ Plik ai.py zaktualizowany!")