This repository was archived by the owner on Jun 11, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsql_generation.py
More file actions
57 lines (40 loc) · 1.47 KB
/
sql_generation.py
File metadata and controls
57 lines (40 loc) · 1.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import re
from pydantic import BaseModel
class Input(BaseModel):
    """Request payload accepted by the SQL-generation skill.

    Attributes:
        question: The natural-language question to turn into SQL.
        database_schema: Optional schema text; ``None`` when the caller
            does not supply one. How it is used is not shown in this
            file — presumably passed to the model as context, confirm.
    """

    # Required field.
    question: str
    # Optional; omitted callers get None.
    database_schema: str | None = None
class Output(BaseModel):
    """Response payload produced by the SQL-generation skill.

    Attributes:
        answer: The generated SQL text. The field must be provided, but
            its value may be ``None``.
        duration: Optional float, ``None`` by default. Its unit is not
            established in this file (NOTE(review): presumably seconds —
            confirm against the producer).
    """

    # Required but nullable: no default is given.
    answer: str | None
    # Optional timing information.
    duration: float | None = None
def extract_sql_text(sql_text: str) -> str:
    """Extract the last SQL ``SELECT`` statement from raw model output.

    ``<think>...</think>`` sections are removed first. The remaining text is
    split into stripped, non-blank lines, which are scanned bottom-up for the
    last line starting with ``SELECT`` (case-insensitive). Starting there,
    lines are collected until one of these happens:

    * a line ends with ``;``, including the ``SELECT`` line itself;
    * another line starts with ``SELECT``;
    * a line starts with ``//``, ``#`` or a Markdown code fence (three
      backticks). That line is not included.

    Full-line SQL comments (lines starting with ``--``) are skipped. Joining
    lines with spaces would otherwise turn such a comment into one that
    swallows the rest of the statement.

    The collected lines are joined with single spaces, trailing ``.``
    characters are stripped, and a ``;`` is appended if missing.

    Args:
        sql_text: Raw text that may contain a SQL query.

    Returns:
        The extracted ``;``-terminated query. If no ``SELECT`` line exists,
        returns the think-stripped, whitespace-trimmed input instead.
    """
    cleaned = re.sub(r"<think>(.*?)</think>", "", sql_text, flags=re.DOTALL)
    lines = [line.strip() for line in cleaned.strip().split("\n") if line.strip()]

    # Scan bottom-up so the last SELECT line in the output is the one used.
    for start in range(len(lines) - 1, -1, -1):
        first = lines[start]
        if not first.upper().startswith("SELECT"):
            continue

        sql_lines = [first]
        # A SELECT line that already ends with ";" is a complete statement;
        # continuing would append any trailing explanation text.
        if not first.endswith(";"):
            for next_line in lines[start + 1 :]:
                if next_line.upper().startswith("SELECT") or next_line.startswith(
                    ("//", "#", "```")
                ):
                    break
                if next_line.startswith("--"):
                    # Joining with spaces would comment out everything after it.
                    continue
                sql_lines.append(next_line)
                if next_line.endswith(";"):
                    break

        sql_query = " ".join(sql_lines).rstrip(".")
        if not sql_query.endswith(";"):
            sql_query += ";"
        return sql_query

    return cleaned.strip()
if __name__ == "__main__":
    # Local manual-run harness: builds a DevCsi and runs one sample question.
    from pharia_skill.testing import DevCsi

    csi = DevCsi()
    # Named skill_input rather than input to avoid shadowing the builtin.
    skill_input = Input(question="Total number of customers per regions?")
    # TODO: Implement the generate_sql skill. It is not defined anywhere in
    # this file yet, so running this script raises NameError on the next line.
    output = generate_sql(csi, skill_input)
    print(f"Generated SQL: {output.answer}")