109 lines
4.1 KiB
Python
109 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Ostateczna poprawka metody generate()"""
|
|
|
|
import re
import textwrap
|
|
|
|
# Source of the replacement method, written at class-member indentation
# (4 spaces); patch_generate() re-indents it to match the target file.
# NOTE: this text is emitted verbatim into ai.py, so its Polish docstring and
# comments are intentionally left untouched.
new_generate_method = '''
    def generate(self, input_ids=None, max_length=100, temperature=1.0, **kwargs):
        """Uniwersalna metoda generate obsługująca różne interfejsy"""
        # Ignoruj nieobsługiwane argumenty (device, max_len, etc.)
        max_len = kwargs.get('max_len', max_length)

        if isinstance(input_ids, str):
            return self.generate_text(input_ids, max_len=max_len, temperature=temperature)
        elif input_ids is not None:
            # Jeśli to tensor, przekonwertuj na tekst
            if hasattr(input_ids, 'tolist'):
                if len(input_ids.shape) > 1:
                    text = self.tokenizer.decode(input_ids[0].tolist())
                else:
                    text = self.tokenizer.decode(input_ids.tolist())
            else:
                text = self.tokenizer.decode(input_ids)
            return self.generate_text(text, max_len=max_len, temperature=temperature)
        else:
            return ""  # Pusty string dla None input
'''


def _indent(line):
    """Return the leading whitespace of *line*."""
    return line[:len(line) - len(line.lstrip())]


def _header_end(lines, idx):
    """Return the index just past the ``def``/``class`` header starting at *idx*.

    Bracket depth is tracked so that a multi-line signature (e.g. black style,
    whose closing ``):`` sits at the header's own indentation) is treated as
    part of the header. Brackets inside string literals or comments on the
    header lines are not recognised; ordinary signatures are unaffected.
    """
    depth = 0
    i = idx
    while i < len(lines):
        line = lines[i]
        depth += sum(map(line.count, '([{')) - sum(map(line.count, ')]}'))
        i += 1
        if depth <= 0:
            break
    return i


def _definition_span(lines, idx):
    """Return ``(body_start, end)`` for the definition whose header is at *idx*.

    ``end`` is exclusive. The definition stops before the first non-blank line
    indented no deeper than its header, so the next method (including its
    decorators) and any module-level code after a class are never consumed.
    Trailing blank lines are excluded so the file's spacing is preserved.

    Assumes every body line is indented deeper than the header: a
    triple-quoted string whose continuation lines start at a shallower column
    would end the span early.
    """
    width = len(_indent(lines[idx]))
    body_start = _header_end(lines, idx)
    end = body_start
    while end < len(lines):
        line = lines[end]
        if line.strip() and len(_indent(line)) <= width:
            break
        end += 1
    while end > body_start and not lines[end - 1].strip():
        end -= 1
    return body_start, end


def _reindent(method_src, indent):
    """Return *method_src* as a list of lines whose ``def`` starts at *indent*."""
    body = textwrap.dedent(method_src).strip('\n')
    return [indent + line if line.strip() else '' for line in body.split('\n')]


def patch_generate(content, method_src, class_name='MiniGPT60M'):
    """Return *content* with ``generate()`` replaced by, or extended with, *method_src*.

    Every existing ``def generate(`` definition is replaced in place and
    re-indented to that definition's own indentation. Decorators written above
    it (e.g. ``@torch.no_grad()``) are kept and decorate the new method.

    If no ``def generate(`` exists, the method is appended at the end of class
    *class_name*. If that class is missing as well, *content* is returned
    unchanged. Progress is reported with ``print``.
    """
    lines = content.split('\n')
    def_indices = [i for i, line in enumerate(lines)
                   if line.lstrip().startswith('def generate(')]

    if def_indices:
        # Work bottom-up so replacing one span never shifts the indices of the
        # definitions that are still to be processed.
        for idx in reversed(def_indices):
            _, end = _definition_span(lines, idx)
            lines[idx:end] = _reindent(method_src, _indent(lines[idx]))
        print("✅ Zaktualizowano metodę generate()")
        return '\n'.join(lines)

    print("❌ Metoda generate() nie znaleziona, dodaję...")
    # \b keeps e.g. "class MiniGPT60MConfig" from being mistaken for the target.
    class_re = re.compile(rf'\s*class\s+{re.escape(class_name)}\b')
    class_idx = next((i for i, line in enumerate(lines) if class_re.match(line)), None)
    if class_idx is None:
        print(f"❌ Nie znaleziono klasy {class_name}, plik nie został zmieniony")
        return content

    body_start, class_end = _definition_span(lines, class_idx)
    # Indent like the class's first member; fall back to one 4-space level.
    member_indent = next(
        (_indent(line) for line in lines[body_start:class_end] if line.strip()),
        _indent(lines[class_idx]) + '    ',
    )
    # Append after the last member rather than before it, so no existing
    # decorator gets separated from its def.
    lines[class_end:class_end] = [''] + _reindent(method_src, member_indent)
    print("✅ Dodano metodę generate() na końcu klasy")
    return '\n'.join(lines)


def main():
    """Patch ai.py in place, writing it back only when something changed."""
    # Explicit UTF-8: the inserted method contains Polish characters, and the
    # locale-dependent default (e.g. cp1252 on Windows) cannot encode them.
    with open('ai.py', 'r', encoding='utf-8') as f:
        content = f.read()

    updated = patch_generate(content, new_generate_method)
    if updated == content:
        print("ℹ️ Brak zmian — plik ai.py nie został zapisany")
        return

    with open('ai.py', 'w', encoding='utf-8') as f:
        f.write(updated)
    print("✅ Plik ai.py zaktualizowany!")


if __name__ == '__main__':
    main()
|