aai/final_fix.py

109 lines
4.1 KiB
Python
Raw Permalink Normal View History

2026-01-26 15:19:15 +00:00
#!/usr/bin/env python3
"""Ostateczna poprawka metody generate()"""
import re
with open('ai.py', 'r') as f:
content = f.read()
# Nowa, uniwersalna metoda generate
new_generate_method = '''
def generate(self, input_ids=None, max_length=100, temperature=1.0, **kwargs):
"""Uniwersalna metoda generate obsługująca różne interfejsy"""
# Ignoruj nieobsługiwane argumenty (device, max_len, etc.)
max_len = kwargs.get('max_len', max_length)
if isinstance(input_ids, str):
return self.generate_text(input_ids, max_len=max_len, temperature=temperature)
elif input_ids is not None:
# Jeśli to tensor, przekonwertuj na tekst
if hasattr(input_ids, 'tolist'):
if len(input_ids.shape) > 1:
text = self.tokenizer.decode(input_ids[0].tolist())
else:
text = self.tokenizer.decode(input_ids.tolist())
else:
text = self.tokenizer.decode(input_ids)
return self.generate_text(text, max_len=max_len, temperature=temperature)
else:
return "" # Pusty string dla None input
'''
# Znajdź i zamień metodę generate
if 'def generate(' in content:
# Użyj regex do znalezienia całej metody
pattern = r'def generate\(.*?\).*?(?=\n def \w+\(|\nclass \w+|\Z)'
# Sprawdź czy regex działa
match = re.search(pattern, content, re.DOTALL)
if match:
content = re.sub(pattern, new_generate_method, content, flags=re.DOTALL)
print("✅ Zaktualizowano metodę generate()")
else:
# Alternatywny sposób: znajdź między def generate a następną def
lines = content.split('\n')
new_lines = []
i = 0
while i < len(lines):
line = lines[i]
new_lines.append(line)
if line.strip().startswith('def generate('):
# Pomijaj stare linie metody aż do następnej metody
i += 1
while i < len(lines) and not lines[i].strip().startswith('def ') and not lines[i].strip().startswith('class '):
i += 1
# Dodaj nową metodę
new_lines.append(new_generate_method)
if i < len(lines):
new_lines.append(lines[i])
i += 1
continue
i += 1
content = '\n'.join(new_lines)
print("✅ Zastąpiono metodę generate() (alternatywna metoda)")
else:
print("❌ Metoda generate() nie znaleziona, dodaję...")
# Dodaj przed ostatnią metodą w klasie
if 'class MiniGPT60M' in content:
# Wstaw przed ostatnim 'def' w klasie
lines = content.split('\n')
new_lines = []
in_class = False
methods_found = []
for i, line in enumerate(lines):
new_lines.append(line)
if 'class MiniGPT60M' in line:
in_class = True
if in_class and line.strip().startswith('def '):
methods_found.append(i)
if methods_found:
# Dodaj przed ostatnią metodą
last_method_idx = methods_found[-1]
new_lines.insert(last_method_idx, '\n' + new_generate_method)
content = '\n'.join(new_lines)
print("✅ Dodano metodę generate()")
else:
# Dodaj na końcu klasy
if 'class MiniGPT60M' in content:
# Znajdź koniec klasy
class_pattern = r'(class MiniGPT60M.*?)(?=\nclass|\Z)'
match = re.search(class_pattern, content, re.DOTALL)
if match:
class_content = match.group(1)
updated_class = class_content.rstrip() + '\n\n' + new_generate_method
content = content.replace(class_content, updated_class)
print("✅ Dodano metodę generate() na końcu klasy")
# Zapisz zmiany
with open('ai.py', 'w') as f:
f.write(content)
print("✅ Plik ai.py zaktualizowany!")