diff --git a/packages/imandrax-codegen/CHANGELOG.md b/packages/imandrax-codegen/CHANGELOG.md
index 7ed197b3..ba39a305 100644
--- a/packages/imandrax-codegen/CHANGELOG.md
+++ b/packages/imandrax-codegen/CHANGELOG.md
@@ -4,6 +4,10 @@ Versioning scheme: ..
 
 ## [Unreleased]
 
+## [18.5.0] - 2026-03-19
+
+- Public API for counter-example source code generation
+
 ## [18.4.0] - 2026-03-06
 
 - Return type definition and test declaration separately
diff --git a/packages/imandrax-codegen/pyproject.toml b/packages/imandrax-codegen/pyproject.toml
index 4d22e6d3..5a5ed273 100644
--- a/packages/imandrax-codegen/pyproject.toml
+++ b/packages/imandrax-codegen/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "imandrax-codegen"
-version = "18.4.2"
+version = "18.5.0"
 description = "Code generator for ImandraX artifact"
 readme = "README.md"
 authors = [
@@ -10,8 +10,8 @@ requires-python = ">=3.12"
 dependencies = [
     "devtools>=0.12.2",
     "dotenv>=0.9.9",
-    "imandrax-api-models>=18.0.0",
-    "imandrax-api[async]>=0.18.0.1",
+    "imandrax-api-models>=18.0.0,<19",
+    "imandrax-api[async]>=0.18.0.1,<0.19",
     "iml-query>=0.3.4",
     "pydantic>=2.12.3",
     "pyyaml>=6.0.3",
@@ -77,6 +77,7 @@ reportUnnecessaryIsInstance = false
 reportUnnecessaryTypeIgnoreComment = true
 reportUnusedImport = false
 reportUnusedExpression = false
+reportUnusedVariable='information'
 
 [tool.ruff]
 
diff --git a/packages/imandrax-codegen/python/imandrax_codegen/gen_src.py b/packages/imandrax-codegen/python/imandrax_codegen/gen_src.py
new file mode 100644
index 00000000..b53a2713
--- /dev/null
+++ b/packages/imandrax-codegen/python/imandrax_codegen/gen_src.py
@@ -0,0 +1,327 @@
+import os
+import re
+from pathlib import Path
+from typing import Any, Literal, assert_never
+
+from imandrax_api import url_dev, url_prod
+from imandrax_api_models import (  # noqa: F401, RUF100
+    Art,
+    DecomposeRes,
+    EvalRes,
+    GetDeclsRes,
+    InstanceRes,
+    VerifyRes,
+)
+from imandrax_api_models.client import ImandraXClient
+from imandrax_codegen.unparse import unparse
+
+from .art_parse import Lang, Mode, ast_of_art, code_of_art
+
+curr_dir = Path(__file__).parent
+
+
+# TODO: Not used
+def _get_fun_arg_types(  # pyright: ignore[reportUnusedFunction]
+    fun_name: str,
+    iml: str,
+    c: ImandraXClient,
+) -> list[str] | None:
+    """Split the type signature of `fun_name` on `->`.
+
+    Typechecks `iml` and looks `fun_name` up among the reported types.
+
+    Returns:
+        The stripped `->`-separated parts of the signature (so for an arrow
+        type the last part is the result type, not an argument), or None
+        if the typecheck result has no entry named `fun_name`.
+    """
+    tc_res = c.typecheck(iml)
+    name_ty_map = {ty.name: ty.ty for ty in tc_res.types}
+    if fun_name not in name_ty_map:
+        return None
+
+    return [part.strip() for part in name_ty_map[fun_name].split('->')]
+
+
+def _extract_type_decl_names(iml_code: str) -> list[str]:
+    """
+    Extract all type definition names from IML (OCaml-syntax) code using regex.
+
+    Args:
+        iml_code: String containing IML code
+
+    Returns:
+        List of type names defined in the code, in order of appearance
+
+    Examples:
+        >>> code = 'type direction = North | South'
+        >>> _extract_type_decl_names(code)
+        ['direction']
+    """
+    # Pattern matches: "type" or "and" keyword followed by optional type parameters, then type name
+    # Handles both regular types and recursive types (type ... and ...)
+    # Also handles parameterized types:
+    # - Single param without parens: type 'a option
+    # - Multi param with parens: type ('a, 'b) container
+    # - Wildcard param: type _ expr (GADTs)
+    # NOTE: this is a textual scan, not a parse, so the `and` of
+    # `let rec f ... and g ...` is matched too (yielding `g`).
+    pattern = r'\b(?:type|and)\s+(?:(?:\([^)]+\)|\'[a-z_][a-zA-Z0-9_]*|_)\s+)?([a-z_][a-zA-Z0-9_]*(?:\s*,\s*[a-z_][a-zA-Z0-9_]*)*)'
+
+    type_names: list[str] = []
+    for match in re.finditer(pattern, iml_code):
+        # The captured group may hold several comma-separated names
+        type_names.extend(name.strip() for name in match.group(1).split(','))
+
+    return type_names
+
+
+class GenSourceCodeError(ValueError):
+    """A response holds no artifact to generate source code from."""
+
+
+def gen_source_code(
+    model_res: DecomposeRes | VerifyRes | InstanceRes,
+    lang: Lang,
+    decls_res: GetDeclsRes,
+) -> tuple[str, str] | GenSourceCodeError:
+    """Generate source code from a decompose / verify / instance response.
+
+    Args:
+        model_res: Response holding the artifact to render. A decompose
+            artifact is rendered in 'fun-decomp' mode; the refuted model of
+            a verify response and the sat model of an instance response are
+            rendered in 'model' mode.
+        lang: Target language of the generated code.
+        decls_res: Declarations rendered as the type definitions.
+
+    Return:
+        Tuple of (type declarations, body), or a GenSourceCodeError (returned,
+        not raised) when `model_res` holds no artifact.
+    """
+    mode: Mode
+    art: Art
+
+    match model_res:
+        case DecomposeRes():
+            mode = 'fun-decomp'
+            if model_res.artifact is None:
+                return GenSourceCodeError('No artifact in decompose response')
+            art = model_res.artifact
+        case VerifyRes():
+            mode = 'model'
+            if (
+                model_res.refuted is None
+                or (model_res.refuted.model is None)
+                or (model_res.refuted.model.artifact is None)
+            ):
+                return GenSourceCodeError(
+                    'No refuted model artifact in verify response'
+                )
+            art = model_res.refuted.model.artifact
+        case InstanceRes():
+            mode = 'model'
+            if (
+                model_res.sat is None
+                or (model_res.sat.model is None)
+                or (model_res.sat.model.artifact is None)
+            ):
+                return GenSourceCodeError(
+                    'No satisfiable model artifact in instance response'
+                )
+            art = model_res.sat.model.artifact
+        case _:
+            # Exhaustiveness guard: fail here rather than on an unbound `art`
+            assert_never(model_res)
+
+    match lang:
+        case 'typescript':
+            type_def_srcs = [
+                code_of_art(decl.artifact, mode='decl', lang=lang)
+                for decl in decls_res.decls
+            ]
+            src_body = code_of_art(art, mode=mode, lang=lang)
+            return (
+                '\n'.join(type_def_srcs),
+                src_body,
+            )
+        case 'python':
+            # TODO(#20):
+            # Python still needs two-stage generation otherwise
+            # we get two `from __future__ import annotations`
+            type_defs_stmts = [
+                ast_of_art(decl.artifact, mode='decl') for decl in decls_res.decls
+            ]
+            type_def_stmts = [stmt for stmts in type_defs_stmts for stmt in stmts]
+            body_stmts = ast_of_art(art, mode=mode)
+            type_def_src = (
+                unparse(type_def_stmts, include_future_import=True)
+                if type_def_stmts
+                else ''
+            )
+            src_body = unparse(body_stmts, include_future_import=False)
+            return (type_def_src, src_body)
+        case _:
+            # Exhaustiveness guard: never fall through and return None
+            assert_never(lang)
+
+
+def _client_with_src(
+    iml: str,
+    imandrax_api_key: str | None,
+    imandrax_env: str | None,
+) -> ImandraXClient:
+    """Create an ImandraX client and evaluate `iml` with it.
+
+    The environment is `imandrax_env`, else `$IMANDRAX_ENV`, else 'prod';
+    'dev' selects the dev URL and anything else the prod URL. The key is
+    `imandrax_api_key`, else `$IMANDRAX_API_KEY`.
+
+    Raises:
+        KeyError: If no key is passed and `IMANDRAX_API_KEY` is not set.
+        ValueError: If evaluating `iml` does not succeed.
+    """
+    env = imandrax_env or os.getenv('IMANDRAX_ENV', 'prod')
+    url = url_dev if env == 'dev' else url_prod
+
+    c = ImandraXClient(
+        auth_token=imandrax_api_key or os.environ['IMANDRAX_API_KEY'],
+        url=url,
+    )
+
+    # Eval IML
+    eval_res: EvalRes = c.eval_src(iml)
+    if eval_res.success is not True:
+        error_msgs = [repr(err.msg) for err in eval_res.errors]
+        raise ValueError(f'Failed to evaluate source code: {error_msgs}')
+
+    return c
+
+
+def _gen_source_code_or_raise(
+    c: ImandraXClient,
+    iml: str,
+    model_res: DecomposeRes | VerifyRes | InstanceRes,
+    lang: Lang,
+) -> tuple[str, str]:
+    """Fetch the declarations of the types in `iml` and generate source code.
+
+    Raises:
+        GenSourceCodeError: If `model_res` holds no artifact.
+    """
+    # Get type declarations
+    arg_types: list[str] = _extract_type_decl_names(iml)
+    decls: GetDeclsRes = c.get_decls(arg_types)
+
+    src_res = gen_source_code(model_res, lang, decls)
+    if isinstance(src_res, GenSourceCodeError):
+        raise src_res
+    return src_res
+
+
+# Main
+# ====================
+
+
+def gen_test_cases(
+    iml: str,
+    decomp_name: str,
+    lang: Lang,
+    other_decomp_kwargs: dict[str, Any] | None = None,
+    imandrax_api_key: str | None = None,
+    imandrax_env: str | None = None,
+) -> tuple[str, str]:
+    """Decomp, get decl, and generate test cases as source code.
+
+    Args:
+        iml: IML source, evaluated before decomposing.
+        decomp_name: Name passed to the decompose call.
+        lang: Target language of the generated code.
+        other_decomp_kwargs: Extra keyword arguments for the decompose call.
+        imandrax_api_key: API key; falls back to `$IMANDRAX_API_KEY`.
+        imandrax_env: 'dev' for the dev URL, anything else for prod; falls
+            back to `$IMANDRAX_ENV`, then 'prod'.
+
+    Return:
+        Tuple of (type declarations, test case definition)
+
+    Raises:
+        KeyError: If no API key is passed or set in the environment.
+        ValueError: If evaluating `iml` does not succeed.
+        GenSourceCodeError: If the decompose response has no artifact.
+    """
+    other_decomp_kwargs = other_decomp_kwargs or {}
+
+    c = _client_with_src(iml, imandrax_api_key, imandrax_env)
+
+    # Decomp
+    decomp_res: DecomposeRes = c.decompose(decomp_name, **other_decomp_kwargs)
+    if decomp_res.artifact is None:
+        # Raised explicitly (an `assert` is stripped under `python -O`) and
+        # before the declarations are fetched.
+        raise GenSourceCodeError('No artifact in decompose response')
+
+    return _gen_source_code_or_raise(c, iml, decomp_res, lang)
+
+
+def gen_counter_example(
+    iml: str,
+    vg_src: str,
+    vg_type: Literal['verify', 'instance'],
+    lang: Lang,
+    vg_hint: str | None = None,
+    imandrax_api_key: str | None = None,
+    imandrax_env: str | None = None,
+) -> tuple[str, str]:
+    """Run a verify / instance goal and generate its model as source code.
+
+    Args:
+        iml: IML source, evaluated before running the goal.
+        vg_src: Source of the goal to verify or find an instance of.
+        vg_type: 'verify' renders the refuting model (the counter-example),
+            'instance' renders the satisfying model.
+        lang: Target language of the generated code.
+        vg_hint: Optional hint passed to the verify / instance call.
+        imandrax_api_key: API key; falls back to `$IMANDRAX_API_KEY`.
+        imandrax_env: 'dev' for the dev URL, anything else for prod; falls
+            back to `$IMANDRAX_ENV`, then 'prod'.
+
+    Return:
+        Tuple of (type declarations, model definition)
+
+    Raises:
+        KeyError: If no API key is passed or set in the environment.
+        ValueError: If evaluating `iml` does not succeed.
+        GenSourceCodeError: If the response has no model artifact (e.g. a
+            verify goal that was not refuted).
+    """
+    c = _client_with_src(iml, imandrax_api_key, imandrax_env)
+
+    # VG
+    match vg_type:
+        case 'verify':
+            model_res = c.verify_src(vg_src, vg_hint)
+        case 'instance':
+            model_res = c.instance_src(vg_src, vg_hint)
+        case _:
+            assert_never(vg_type)
+
+    return _gen_source_code_or_raise(c, iml, model_res, lang)
+
+
+if __name__ == '__main__':
+    import dotenv
+
+    dotenv.load_dotenv()
+    iml = """
+    let f x = x + 1
+    """
+
+    instance_src = 'fun x -> f x = 2'
+
+    res = '\n'.join(
+        gen_counter_example(
+            iml=iml, vg_src=instance_src, vg_type='instance', lang='python'
+        )
+    )
+    print(res)
diff --git a/packages/imandrax-codegen/python/imandrax_codegen/gen_tests.py b/packages/imandrax-codegen/python/imandrax_codegen/gen_tests.py
deleted file mode 100644
index c720d63c..00000000
--- a/packages/imandrax-codegen/python/imandrax_codegen/gen_tests.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import os
-import re
-from pathlib import Path
-from typing import Any
-
-from imandrax_api import url_dev, url_prod
-from imandrax_api_models import DecomposeRes, EvalRes  # noqa: F401, RUF100
-from imandrax_api_models.client import ImandraXClient
-from imandrax_codegen.unparse import unparse
-
-from .art_parse import Lang, ast_of_art, code_of_art
-
-curr_dir = Path(__file__).parent
-
-
-# TODO: Not used
-def _get_fun_arg_types(  # pyright: ignore[reportUnusedFunction]
-    fun_name: str,
-    iml: str,
-    c: ImandraXClient,
-) -> list[str] | None:
-    """Get the argument types of a function."""
-    tc_res = c.typecheck(iml)
-    name_ty_map = {ty.name: ty.ty for ty in tc_res.types}
-    if fun_name not in name_ty_map:
-        return None
-
-    return list(map(lambda s: s.strip(), name_ty_map[fun_name].split('->')))
-
-
-def _extract_type_decl_names(iml_code: str) -> list[str]:
-    """
-    Extract all type definition names from OCaml code using regex.
-
-    Args:
-        ocaml_code: String containing OCaml code
-
-    Returns:
-        List of type names defined in the code
-
-    Examples:
-        >>> code = 'type direction = North | South'
-        >>> extract_ocaml_type_names(code)
-        ['direction']
-    """
-    # Pattern matches: "type" or "and" keyword followed by optional type parameters, then type name
-    # Handles both regular types and recursive types (type ... and ...)
-    # Also handles parameterized types:
-    # - Single param without parens: type 'a option
-    # - Multi param with parens: type ('a, 'b) container
-    # - Wildcard param: type _ expr (GADTs)
-    pattern = r'\b(?:type|and)\s+(?:(?:\([^)]+\)|\'[a-z_][a-zA-Z0-9_]*|_)\s+)?([a-z_][a-zA-Z0-9_]*(?:\s*,\s*[a-z_][a-zA-Z0-9_]*)*)'
-
-    matches = re.finditer(pattern, iml_code)
-    type_names: list[str] = []
-
-    for match in matches:
-        # Extract the captured group (type name(s))
-        names = match.group(1)
-        # Split by comma in case of mutually recursive types: type t1, t2 = ...
-        for name in names.split(','):
-            type_names.append(name.strip())
-
-    return type_names
-
-
-# Main
-# ====================
-
-
-def gen_test_cases(
-    iml: str,
-    decomp_name: str,
-    lang: Lang,
-    other_decomp_kwargs: dict[str, Any] | None = None,
-    imandrax_api_key: str | None = None,
-    imandrax_env: str | None = None,
-) -> tuple[str, str]:
-    """Decomp, get decl, and generate test cases as source code.
-
-    Return:
-        Tuple of (type declarations, test case definition)
-    """
-
-    other_decomp_kwargs = other_decomp_kwargs or {}
-
-    env = imandrax_env or os.getenv('IMANDRAX_ENV', 'prod')
-    url = url_dev if env == 'dev' else url_prod
-
-    c = ImandraXClient(
-        auth_token=imandrax_api_key or os.environ['IMANDRAX_API_KEY'],
-        url=url,
-    )
-
-    # Eval IML
-    eval_res: EvalRes = c.eval_src(iml)
-    if eval_res.success is not True:
-        error_msgs = [repr(err.msg) for err in eval_res.errors]
-        raise ValueError(f'Failed to evaluate source code: {error_msgs}')
-
-    decomp_res: DecomposeRes = c.decompose(decomp_name, **other_decomp_kwargs)
-    decomp_art = decomp_res.artifact
-    assert decomp_art, 'No artifact returned from decompose'
-
-    arg_types: list[str] = _extract_type_decl_names(iml)
-
-    decls = c.get_decls(arg_types)
-
-    match lang:
-        case 'typescript':
-            type_def_srcs = [
-                code_of_art(decl.artifact, mode='decl', lang=lang)
-                for decl in decls.decls
-            ]
-            test_def_src = code_of_art(decomp_art, mode='fun-decomp', lang=lang)
-            return (
-                '\n'.join(type_def_srcs),
-                test_def_src,
-            )
-        case 'python':
-            # TODO(#20):
-            # Python still needs two-stage generation otherwise
-            # we get two `from __future__ import annotations`
-            type_defs_stmts = [
-                ast_of_art(decl.artifact, mode='decl') for decl in decls.decls
-            ]
-            type_def_stmts = [stmt for stmts in type_defs_stmts for stmt in stmts]
-            test_def_stmts = ast_of_art(decomp_art, mode='fun-decomp')
-            type_def_code = unparse(type_def_stmts) if type_def_stmts else ''
-            test_def_code = unparse(
-                test_def_stmts, include_future_import=not type_def_stmts
-            )
-            return (type_def_code, test_def_code)
diff --git a/packages/imandrax-codegen/test/python/test_extract_type_decl.py b/packages/imandrax-codegen/test/python/test_extract_type_decl.py
index 82f33844..a4d6f543 100644
--- a/packages/imandrax-codegen/test/python/test_extract_type_decl.py
+++ b/packages/imandrax-codegen/test/python/test_extract_type_decl.py
@@ -1,6 +1,6 @@
 """Test cases for extract_type_decl_names function."""
 
-from imandrax_codegen.gen_tests import (
+from imandrax_codegen.gen_src import (
     _extract_type_decl_names,  # pyright: ignore[reportPrivateUsage]
 )
 
diff --git a/packages/imandrax-codegen/test/python/test_unparse.py b/packages/imandrax-codegen/test/python/test_unparse.py
index 4326efbb..101224d7 100644
--- a/packages/imandrax-codegen/test/python/test_unparse.py
+++ b/packages/imandrax-codegen/test/python/test_unparse.py
@@ -3,7 +3,7 @@ from typing import Any
 
 
 import yaml
-from imandrax_codegen.gen_tests import Lang, gen_test_cases
+from imandrax_codegen.gen_src import Lang, gen_test_cases
 from imandrax_codegen.unparse import join_code_parts
 from inline_snapshot import snapshot
 
@@ -113,9 +113,6 @@ def test_nested_conditions():
         lang='python',
     )
     assert code == snapshot('''\
-from __future__ import annotations
-
-
 def test_1():
     """test_1
 
@@ -179,9 +176,6 @@ def test_list_operations():
         lang='python',
     )
     assert code == snapshot('''\
-from __future__ import annotations
-
-
 def test_1():
     """test_1
 
@@ -304,9 +298,6 @@ def test_composite_tuple():
         lang='python',
     )
     assert code == snapshot('''\
-from __future__ import annotations
-
-
 def test_1():
     """test_1
 
@@ -356,9 +347,6 @@ def test_with_basis():
         lang='python',
     )
     assert code == snapshot('''\
-from __future__ import annotations
-
-
 def test_1():
     """test_1
 
@@ -394,9 +382,6 @@ def test_primitive_real():
         lang='python',
     )
     assert code == snapshot('''\
-from __future__ import annotations
-
-
 def test_1():
     """test_1
 
@@ -420,9 +405,6 @@ def test_multiple_parameters():
         lang='python',
     )
     assert code == snapshot('''\
-from __future__ import annotations
-
-
 def test_1():
     """test_1
 
@@ -590,8 +572,6 @@ def test_option_type():
         lang='python',
     )
     assert code == snapshot('''\
-from __future__ import annotations
-
 from dataclasses import dataclass
 from typing import Generic, TypeAlias, TypeVar
 
@@ -655,9 +635,6 @@ def test_basic():
         lang='python',
     )
     assert code == snapshot('''\
-from __future__ import annotations
-
-
 def test_1():
     """test_1
 
@@ -693,9 +670,6 @@ def test_primitive_bool():
         lang='python',
     )
     assert code == snapshot('''\
-from __future__ import annotations
-
-
 def test_1():
     """test_1
 
@@ -759,9 +733,6 @@ def test_with_guards():
         lang='python',
     )
     assert code == snapshot('''\
-from __future__ import annotations
-
-
 def test_1():
     """test_1
 
@@ -896,9 +867,6 @@ def test_primitive_int():
         lang='python',
     )
     assert code == snapshot('''\
-from __future__ import annotations
-
-
 def test_1():
     """test_1
 
diff --git a/packages/imandrax-codegen/uv.lock b/packages/imandrax-codegen/uv.lock
index 6bb7b866..7c53b021 100644
--- a/packages/imandrax-codegen/uv.lock
+++ b/packages/imandrax-codegen/uv.lock
@@ -437,7 +437,7 @@ wheels = [
 
 [[package]]
 name = "imandrax-codegen"
-version = "18.4.2"
+version = "18.5.0"
 source = { editable = "." }
 dependencies = [
     { name = "devtools" },
@@ -462,8 +462,8 @@ dev = [
 requires-dist = [
     { name = "devtools", specifier = ">=0.12.2" },
     { name = "dotenv", specifier = ">=0.9.9" },
-    { name = "imandrax-api", extras = ["async"], specifier = ">=0.18.0.1" },
-    { name = "imandrax-api-models", specifier = ">=18.0.0" },
+    { name = "imandrax-api", extras = ["async"], specifier = ">=0.18.0.1,<0.19" },
+    { name = "imandrax-api-models", specifier = ">=18.0.0,<19" },
     { name = "iml-query", specifier = ">=0.3.4" },
     { name = "pydantic", specifier = ">=2.12.3" },
     { name = "pyyaml", specifier = ">=6.0.3" },