diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 2b6d15c3e..39281eb04 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -27,7 +27,7 @@ "settings": { "python.testing.pytestEnabled": true, "python.testing.cwd": "${workspaceFolder}/tests", - "python.envFile": "${workspaceFolder}/argilla/.env.test", + "python.envFile": "${workspaceFolder}/extralit/.env.test", "python.testing.pytestArgs": [ "-vs", "--disable-warnings" ], "python.defaultInterpreterPath": "/opt/conda/bin/python", "python.condaPath": "/usr/local/bin/micromamba", diff --git a/.devcontainer/docker-compose/devcontainer.json b/.devcontainer/docker-compose/devcontainer.json index 9b7f01b71..e41959f44 100644 --- a/.devcontainer/docker-compose/devcontainer.json +++ b/.devcontainer/docker-compose/devcontainer.json @@ -97,14 +97,14 @@ }, "settings": { "python.testing.pytestEnabled": true, - "python.testing.cwd": "${workspaceFolder}/argilla/", + "python.testing.cwd": "${workspaceFolder}/extralit/", "python.testing.pytestArgs": [ "-vv", "--disable-warnings" ], "python.defaultInterpreterPath": "/opt/conda/bin/python", "python.condaPath": "/usr/local/bin/micromamba", - "python.envFile": "${workspaceFolder}/argilla/.env.test", + "python.envFile": "${workspaceFolder}/extralit/.env.test", "search.exclude": { "argilla-server/src/argilla_server/static/": true, "argilla-frontend/dist/": true, diff --git a/.devcontainer/docker-compose/setup.sh b/.devcontainer/docker-compose/setup.sh index c44b4a96a..37111575c 100644 --- a/.devcontainer/docker-compose/setup.sh +++ b/.devcontainer/docker-compose/setup.sh @@ -7,7 +7,7 @@ if ! 
pip list | grep -q "extralit"; then pdm config python.install_root /opt/conda/ uv pip install -q "sentence-transformers<3.0.0" transformers "textdescriptives<3.0.0" \ -e /workspaces/extralit/argilla-server/ && \ - uv pip install -q -e /workspaces/extralit/argilla/ + uv pip install -q -e /workspaces/extralit/extralit/ else echo "Package 'extralit' is already installed. Skipping installation." fi diff --git a/.devcontainer/setup.sh b/.devcontainer/setup.sh index 883f55620..c54b8ea65 100644 --- a/.devcontainer/setup.sh +++ b/.devcontainer/setup.sh @@ -19,7 +19,7 @@ if ! pip list | grep -q "extralit"; then pdm config use_uv true pdm config python.install_root /opt/conda/ uv pip install -e /workspaces/extralit/argilla-server/ - uv pip install -e /workspaces/extralit/argilla/ + uv pip install -e /workspaces/extralit/extralit/ else echo 'Package 'extralit' is already installed. Skipping installation.' fi diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 0c619cf7c..07e20e5a0 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,26 +6,26 @@ # The last matching pattern takes precedence. 
# SDK and Core Extraction -/argilla/* @extralit/sdk -/argilla/src/**/* @extralit/sdk -/argilla/tests/**/* @extralit/sdk -/argilla/pyproject.toml @extralit/sdk -/argilla/pdm.lock @extralit/sdk +/extralit/* @extralit/sdk +/extralit/src/**/* @extralit/sdk +/extralit/tests/**/* @extralit/sdk +/extralit/pyproject.toml @extralit/sdk +/extralit/pdm.lock @extralit/sdk # Backend Server -/argilla-server/**/* @extralit/backend +/extralit-server/**/* @extralit/backend # Frontend -/argilla-frontend/**/* @extralit/frontend +/extralit-frontend/**/* @extralit/frontend # Legacy compatibility layer -/argilla-v1/**/* @extralit/sdk @extralit/backend +/extralit-v1/**/* @extralit/sdk @extralit/backend # AI/ML-specific code -/argilla/src/extralit/{extraction,metrics,preprocessing,schema}/**/* @extralit/ai -/argilla/src/extralit/pipeline/**/* @extralit/ai @extralit/backend -/argilla/src/extralit/server/**/* @extralit/ai @extralit/backend -/argilla/src/extralit/storage/**/* @extralit/ai @extralit/backend +/extralit/src/extralit/{extraction,metrics,preprocessing,schema}/**/* @extralit/ai +/extralit/src/extralit/pipeline/**/* @extralit/ai @extralit/backend +/extralit/src/extralit/server/**/* @extralit/ai @extralit/backend +/extralit/src/extralit/storage/**/* @extralit/ai @extralit/backend # Infrastructure and configuration /.github/** @extralit/infra @@ -37,7 +37,7 @@ codecov.yml @extralit/infra # Documentation *.md @extralit/docs -/argilla/docs/**/* @extralit/docs +/extralit/docs/**/* @extralit/docs /examples/**/* @extralit/docs @extralit/sdk # Security and legal diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 08b7efdb5..b5373d3a8 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -4,9 +4,9 @@ This directory contains GitHub Actions workflows for building, testing, and depl ## Key Workflows -### `argilla.yml` +### `extralit.yml` -Builds and publishes the `argilla` SDK Python package. 
+Builds and publishes the `extralit` SDK Python package. - **Trigger**: Push to main/develop/releases branches, pull requests, or manual dispatch - **Python versions**: 3.9, 3.10, 3.11, 3.12, 3.13 @@ -57,7 +57,7 @@ The workflows set various environment variables: Additional environment variables are set in specific workflows: - For `argilla-server.yml`: Database connection variables for Postgres, Elasticsearch, Redis, and MinIO -- For `argilla.yml`: HuggingFace credentials for integration tests +- For `extralit.yml`: HuggingFace credentials for integration tests ## Common Issues & Solutions diff --git a/.github/workflows/argilla-v1.yml b/.github/workflows/argilla-v1.yml index 52b0a7f79..213684103 100644 --- a/.github/workflows/argilla-v1.yml +++ b/.github/workflows/argilla-v1.yml @@ -41,27 +41,9 @@ jobs: use-mamba: true activate-environment: argilla - - name: Get date for conda cache - id: get-date - run: echo "::set-output name=today::$(/bin/date -u '+%Y%m%d')" - shell: bash - - - name: Cache Conda env - uses: actions/cache@v3 - id: cache - with: - path: ${{ env.CONDA }}/envs - key: conda-${{ runner.os }}--${{ runner.arch }}--${{ steps.get-date.outputs.today }}-${{ hashFiles('argilla-v1/environment_dev.yml') }}-${{ env.CACHE_NUMBER }} - - name: Update environment - if: steps.cache.outputs.cache-hit != 'true' - run: mamba env update -n argilla -f environment_dev.yml - - - name: Cache pip 👜 - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ env.CACHE_NUMBER }}-${{ hashFiles('pyproject.toml') }} + run: | + mamba env update -n argilla -f environment_dev.yml - name: Set huggingface hub credentials if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/develop' || startsWith(github.ref, 'refs/heads/releases') diff --git a/.github/workflows/argilla.docs.yml b/.github/workflows/extralit.docs.yml similarity index 93% rename from .github/workflows/argilla.docs.yml rename to .github/workflows/extralit.docs.yml index 
cdf34729c..43094f496 100644 --- a/.github/workflows/argilla.docs.yml +++ b/.github/workflows/extralit.docs.yml @@ -15,13 +15,13 @@ on: - "develop" - "docs/**" paths: - - ".github/workflows/argilla.docs.yml" - - "argilla/docs/**" - - "argilla/mkdocs.yml" + - ".github/workflows/extralit.docs.yml" + - "extralit/docs/**" + - "extralit/mkdocs.yml" defaults: run: - working-directory: argilla + working-directory: extralit permissions: contents: write @@ -45,12 +45,12 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v5 with: - pyproject-file: "argilla/pyproject.toml" + pyproject-file: "extralit/pyproject.toml" python-version: "3.10" enable-cache: true cache-local-path: ~/.cache/uv ignore-nothing-to-cache: true - cache-dependency-glob: "argilla/pdm.lock" + cache-dependency-glob: "extralit/pdm.lock" - name: Setup PDM uses: pdm-project/setup-pdm@v4 @@ -58,7 +58,7 @@ jobs: python-version: "3.10" cache: true cache-dependency-path: | - argilla/pdm.lock + extralit/pdm.lock - name: Install dependencies env: diff --git a/.github/workflows/argilla.yml b/.github/workflows/extralit.yml similarity index 91% rename from .github/workflows/argilla.yml rename to .github/workflows/extralit.yml index 9390143bb..6e94ed68c 100644 --- a/.github/workflows/argilla.yml +++ b/.github/workflows/extralit.yml @@ -1,4 +1,4 @@ -name: Build and publish the `argilla` sdk python package +name: Build and publish the `extralit` sdk python package concurrency: group: ${{ github.workflow }}-${{ github.sha }} @@ -13,12 +13,12 @@ on: - develop - releases/** paths: - - "argilla/**" + - "extralit/**" - ".github/workflows/argilla.*" pull_request: paths: - - "argilla/**" + - "extralit/**" - "argilla-server/**" permissions: @@ -45,7 +45,7 @@ jobs: runs-on: ubuntu-22.04 defaults: run: - working-directory: argilla + working-directory: extralit strategy: fail-fast: true matrix: @@ -55,19 +55,19 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v5 with: - pyproject-file: "argilla/pyproject.toml" + pyproject-file: 
"extralit/pyproject.toml" python-version: ${{ matrix.python-version }} enable-cache: true cache-local-path: ~/.cache/uv ignore-nothing-to-cache: true - cache-dependency-glob: "argilla/pdm.lock" + cache-dependency-glob: "extralit/pdm.lock" - name: Setup PDM uses: pdm-project/setup-pdm@v4 with: python-version: ${{ matrix.python-version }} cache: true cache-dependency-path: | - argilla/pdm.lock + extralit/pdm.lock - name: Install dependencies env: PDM_IGNORE_ACTIVE_VENV: 1 @@ -132,8 +132,8 @@ jobs: # Upload the package to be used in the next jobs only once if: ${{ matrix.python-version == '3.9' }} with: - name: argilla - path: argilla/dist + name: extralit + path: extralit/dist # This job will publish argilla package into PyPI repository publish_release: @@ -153,7 +153,7 @@ jobs: defaults: run: shell: bash -l {0} - working-directory: argilla + working-directory: extralit steps: - name: Checkout Code 🛎 @@ -162,7 +162,7 @@ jobs: - name: Update repo visualizer uses: githubocto/repo-visualizer@0.7.1 with: - root_path: "argilla/" + root_path: "extralit/" excluded_paths: "dist,build,node_modules,docs,tests,.swm,assets,.github,package-lock.json,pdm.lock" excluded_globs: "*.spec.js;**/*.{png,jpg,svg,md};**/!(*.module).ts,**/__pycache__/,**/__mocks__/,LICENSE*,**/.gitignore,**/*.egg-info/,**/.*/" output_file: "repo-visualizer.svg" @@ -178,16 +178,16 @@ jobs: - name: Download python package uses: actions/download-artifact@v4 with: - name: argilla - path: argilla/dist + name: extralit + path: extralit/dist - name: Setup PDM uses: pdm-project/setup-pdm@v4 with: cache: true - python-version-file: argilla/pyproject.toml + python-version-file: extralit/pyproject.toml cache-dependency-path: | - argilla/pdm.lock + extralit/pdm.lock - name: Read package info run: | diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 9c17dffe8..b4815f495 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -13,7 +13,7 @@ on: # - cron: '15 
16 * * 3' pull_request: paths: - - "argilla/**" + - "extralit/**" - "argilla-server/**" # Declare default permissions as read only. diff --git a/.gitignore b/.gitignore index fbf8410b6..c33d22179 100644 --- a/.gitignore +++ b/.gitignore @@ -153,4 +153,4 @@ src/**/server/static/ # App generated files argilla-server/src/argilla_server/static -argilla/site +extralit/site diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 86929db9b..e72f43c19 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: rev: v4.6.0 hooks: - id: check-yaml - exclude: argilla/mkdocs.yml|examples/deployments/k8s + exclude: extralit/mkdocs.yml|examples/deployments/k8s - id: end-of-file-fixer exclude_types: [text, jupyter] - id: trailing-whitespace @@ -14,14 +14,14 @@ repos: - id: ruff-format ############################################################################## - # argilla specific hooks + # extralit SDK specific hooks ############################################################################## - repo: https://github.com/pycqa/autoflake rev: v2.2.1 hooks: - id: autoflake - name: "Remove unused imports and variables in argilla" - files: '^argilla/.*\.py$' + name: "Remove unused imports and variables" + files: '^extralit/.*\.py$' args: - "--remove-all-unused-imports" - "--remove-unused-variables" @@ -30,7 +30,7 @@ repos: rev: v0.4.8 hooks: - id: ruff - files: 'argilla/src/.*\.py$' + files: 'extralit/src/.*\.py$' args: - --fix - repo: https://github.com/Lucas-C/pre-commit-hooks @@ -38,17 +38,17 @@ repos: hooks: - id: insert-license name: "Insert license header in Python source files" - files: '^argilla/.*\.py$' - exclude: ^argilla/docs/snippets/ + files: '^extralit/.*\.py$' + exclude: ^extralit/docs/snippets/ args: - --license-filepath - - argilla/LICENSE_HEADER + - extralit/LICENSE_HEADER - --fuzzy-match-generates-todo - repo: https://github.com/kynan/nbstripout rev: 0.7.1 hooks: - id: nbstripout - files: '^argilla/.*\.ipynb$' + files: 
'^extralit/.*\.ipynb$' args: - --keep-count - --keep-output @@ -59,6 +59,7 @@ repos: # - --keep-execution-count # - --keep-metadata # - --keep-version + ############################################################################## # argilla-server specific hooks ############################################################################## @@ -92,6 +93,37 @@ repos: - argilla-server/LICENSE_HEADER - --fuzzy-match-generates-todo + ############################################################################## + # argilla-v1 SDK specific hooks + ############################################################################## + - repo: https://github.com/pycqa/autoflake + rev: v2.2.1 + hooks: + - id: autoflake + name: "Remove unused imports and variables" + files: '^argilla-v1/.*\.py$' + args: + - "--remove-all-unused-imports" + - "--remove-unused-variables" + - "--in-place" + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.4.8 + hooks: + - id: ruff + files: 'argilla-v1/src/.*\.py$' + args: + - --fix + - repo: https://github.com/Lucas-C/pre-commit-hooks + rev: v1.5.5 + hooks: + - id: insert-license + name: "Insert license header in Python source files" + files: '^argilla-v1/.*\.py$' + args: + - --license-filepath + - extralit/LICENSE_HEADER + - --fuzzy-match-generates-todo + ############################################################################## # Helm lint hook ############################################################################## diff --git a/README.md b/README.md index b772b4eee..0f35876bf 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@

- Extralit + Extralit

diff --git a/Tiltfile b/Tiltfile index 3074b9513..76a89cfb9 100644 --- a/Tiltfile +++ b/Tiltfile @@ -91,7 +91,7 @@ docker_build( context='argilla-server/', build_args={'ENV': ENV, 'USERS_DB': USERS_DB}, dockerfile='argilla-server/docker/server/dev.argilla_server.dockerfile', - ignore=['examples/', 'argilla/', '.*', '**/__pycache__', '*.pyc', 'CHANGELOG.md'], + ignore=['examples/', 'extralit/', '.*', '**/__pycache__', '*.pyc', 'CHANGELOG.md'], live_update=[ # Sync the source code to the container sync('argilla-server/src/', '/home/argilla/src/'), @@ -166,15 +166,15 @@ helm_resource( ) # Extralit server -if not os.path.exists('argilla/dist/'): - local('pdm build', dir='argilla') +if not os.path.exists('extralit/dist/'): + local('pdm build', dir='extralit') docker_build( "{DOCKER_REPO}/extralit-server".format(DOCKER_REPO=DOCKER_REPO), - context='argilla/', - dockerfile='argilla/docker/extralit.dockerfile', + context='extralit/', + dockerfile='extralit/docker/extralit.dockerfile', ignore=['.*', 'argilla-frontend/', 'argilla-server/', '**/__pycache__', '*.pyc'], live_update=[ - sync('argilla/', '/home/extralit/'), + sync('extralit/', '/home/extralit/'), ] ) extralit_k8s_yaml = read_yaml_stream('examples/deployments/k8s/extralit-deployment.yaml') diff --git a/argilla-frontend/README.md b/argilla-frontend/README.md index aa684a45f..ecc782da7 100644 --- a/argilla-frontend/README.md +++ b/argilla-frontend/README.md @@ -1,5 +1,5 @@

- Extralit + Extralit
Extralit Frontend
diff --git a/argilla-frontend/translation/ja.js b/argilla-frontend/translation/ja.js index c4579bbdf..621acc24e 100644 --- a/argilla-frontend/translation/ja.js +++ b/argilla-frontend/translation/ja.js @@ -286,7 +286,7 @@ export default { guidesText: "以下のガイドを参照してください", pasteRepoIdPlaceholder: "リポジトリIDを貼り付け <例> stanfordnlp/imdb", demoLink: - "こちらのデモにログインしてArgillaを試してみましょう", + "こちらのデモにログインしてArgillaを試してみましょう", name: "データセット名", updatedAt: "更新日", createdAt: "作成日", diff --git a/argilla-server/.env.dev b/argilla-server/.env.dev index 0f23f7983..7ab1208b6 100644 --- a/argilla-server/.env.dev +++ b/argilla-server/.env.dev @@ -1,7 +1,7 @@ OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES # Needed by RQ to work with forked processes on MacOS ALEMBIC_CONFIG=src/argilla_server/alembic.ini ARGILLA_AUTH_SECRET_KEY=8VO7na5N/jQx+yP/N+HlE8q51vPdrxqlh6OzoebIyko= # With this we avoid using a different key every time the server is reloaded -ARGILLA_DATABASE_URL=sqlite+aiosqlite:///${HOME}/.argilla/argilla-dev.db?check_same_thread=False +ARGILLA_DATABASE_URL=sqlite+aiosqlite:///${HOME}/.extralit/argilla-dev.db?check_same_thread=False # S3 Configuration ARGILLA_S3_ENDPOINT=http://minio:9000 diff --git a/argilla-server/.env.test b/argilla-server/.env.test index 55d04fe76..7fba1a591 100644 --- a/argilla-server/.env.test +++ b/argilla-server/.env.test @@ -1,2 +1,2 @@ -ARGILLA_DATABASE_URL=sqlite+aiosqlite:///${HOME}/.argilla/argilla-test.db?check_same_thread=False +ARGILLA_DATABASE_URL=sqlite+aiosqlite:///${HOME}/.extralit/argilla-test.db?check_same_thread=False ARGILLA_REDIS_URL=redis://localhost:6379/1 # Using a different Redis database for testing diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 34683e988..247da2673 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -23,17 +23,17 @@ These are the section headers that we use: - Enhanced CLI commands for adding documents to include reference and improved error handling. 
- Added `from_file` method to Document for creating instances from file paths or URLs. - -### Fixed -- LocalFileStorage implementation to mimic Minio or S3 storage. -- Fixed `argilla-hf-spaces` s3 environment files. -- Used `uv` in `argilla-server` and `argilla-hf-spaces` Dockerfiles - ### Changed - Adjustments to Dockerfiles for clarity and consistency. - Updated `argilla-server` Dockerfile to use `uv` for installing server dependencies. - Refactored API schemas to use `DocumentCreate` and `DocumentDelete` for better clarity. - Updated elasticsearch to 8.17.0 in `argilla-hf-spaces` Dockerfile. +- Changed `home_path` to `~/.extralit/` from `~/.argilla/` to align with new project structure. + +### Fixed +- LocalFileStorage implementation to mimic Minio or S3 storage. +- Fixed `argilla-hf-spaces` s3 environment files. +- Used `uv` in `argilla-server` and `argilla-hf-spaces` Dockerfiles ## [Argilla] [2.8.0](https://github.com/argilla-io/argilla/compare/v2.7.1...v2.8.0) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/documents.py b/argilla-server/src/argilla_server/api/handlers/v1/documents.py index 3cb10dd23..3557a4f09 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/documents.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/documents.py @@ -32,7 +32,7 @@ if TYPE_CHECKING: from argilla_server.models import Document -_LOGGER = logging.getLogger("documents") +_LOGGER = logging.getLogger(__name__) router = APIRouter(tags=["documents"]) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/files.py b/argilla-server/src/argilla_server/api/handlers/v1/files.py index 4b09fd803..7887de9de 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/files.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/files.py @@ -1,3 +1,17 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging from typing import Optional @@ -11,20 +25,20 @@ from argilla_server.api.schemas.v1.files import ListObjectsResponse, ObjectMetadata from argilla_server.security import auth -_LOGGER = logging.getLogger("files") +_LOGGER = logging.getLogger(__name__) router = APIRouter(tags=["files"]) + @router.get("/file/{bucket}/{object:path}") async def get_file( *, - bucket: str, - object: str, + bucket: str, + object: str, version_id: Optional[str] = None, client: Minio = Depends(files.get_minio_client), - current_user: Optional[User] = Security(auth.get_optional_current_user) - ): - + current_user: Optional[User] = Security(auth.get_optional_current_user), +): # TODO Check if the current user is in the workspace to have access to the s3 bucket of the same name # if current_user is not None or current_user.role != "owner": # await authorize(current_user, FilePolicy.get(bucket)) @@ -33,39 +47,37 @@ async def get_file( file_response = files.get_object(client, bucket, object, version_id=version_id, include_versions=True) return StreamingResponse( - file_response.response, - media_type=file_response.metadata.content_type, - headers=file_response.http_headers + file_response.response, media_type=file_response.metadata.content_type, headers=file_response.http_headers ) except S3Error as se: _LOGGER.error(f"Error getting object '{bucket}/{object}': {se}") raise HTTPException(status_code=404, detail=f"No object at path '{object}' was found") from se - + except Exception as e: _LOGGER.error(f"Error getting object '{bucket}/{object}': {e}") raise 
HTTPException(status_code=500, detail=str(e)) from e - + @router.post("/file/{bucket}/{object:path}", response_model=ObjectMetadata) async def put_file( *, - bucket: str, - object: str, + bucket: str, + object: str, file: UploadFile = File(...), client: Minio = Depends(files.get_minio_client), - current_user: User = Security(auth.get_current_user) - ): - + current_user: User = Security(auth.get_current_user), +): # Check if the current user is in the workspace to have access to the s3 bucket of the same name await authorize(current_user, FilePolicy.put_object(bucket)) - + try: - response = files.put_object(client, bucket, object, - data=file.file, size=file.size, content_type=file.content_type) + response = files.put_object( + client, bucket, object, data=file.file, size=file.size, content_type=file.content_type + ) return response except S3Error as se: raise HTTPException(status_code=500, detail=f"Internal server error: {se.message}") from se - + @router.get("/files/{bucket}/{prefix:path}", response_model=ListObjectsResponse) async def list_objects( @@ -73,38 +85,43 @@ async def list_objects( bucket: str, prefix: str, include_version=True, - recursive = True, + recursive=True, start_after: Optional[str] = None, client: Minio = Depends(files.get_minio_client), current_user: User = Security(auth.get_current_user), - ): +): # Check if the current user is in the workspace to have access to the s3 bucket of the same name await authorize(current_user, FilePolicy.list(bucket)) try: - objects = files.list_objects(client, bucket, prefix=prefix, include_version=include_version, recursive=recursive, start_after=start_after) + objects = files.list_objects( + client, bucket, prefix=prefix, include_version=include_version, recursive=recursive, start_after=start_after + ) return objects except S3Error as se: _LOGGER.error(f"Error listing objects in '{bucket}/{prefix}': {se}") if se.code == "NoSuchBucket": - raise HTTPException(status_code=404, detail=f"Bucket '{bucket}' not 
found, please run `rg.Workspace.create('{bucket}')` to create the S3 bucket.") from se + raise HTTPException( + status_code=404, + detail=f"Bucket '{bucket}' not found, please run `rg.Workspace.create('{bucket}')` to create the S3 bucket.", + ) from se else: - raise HTTPException(status_code=404, detail=f"Cannot list objects as '{bucket}/{prefix}' is not found") from se + raise HTTPException( + status_code=404, detail=f"Cannot list objects as '{bucket}/{prefix}' is not found" + ) from se except Exception as e: raise e - @router.delete("/file/{bucket}/{object:path}") async def delete_files( *, - bucket: str, - object: str, + bucket: str, + object: str, version_id: Optional[str] = None, client: Minio = Depends(files.get_minio_client), - current_user: User = Security(auth.get_current_user) - ): - + current_user: User = Security(auth.get_current_user), +): # Check if the current user is in the workspace to have access to the s3 bucket of the same name await authorize(current_user, FilePolicy.delete(bucket)) @@ -115,4 +132,3 @@ async def delete_files( raise HTTPException(status_code=500, detail="Internal server error") from se except Exception as e: raise e - diff --git a/argilla-server/src/argilla_server/api/handlers/v1/models.py b/argilla-server/src/argilla_server/api/handlers/v1/models.py index 0aafcf6b1..539ac212c 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/models.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/models.py @@ -1,8 +1,22 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + import logging from urllib.parse import urljoin import httpx -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from starlette.requests import Request from starlette.responses import StreamingResponse @@ -11,30 +25,32 @@ from argilla_server.settings import settings from argilla_server.errors import UnauthorizedError, BadRequestError -_LOGGER = logging.getLogger("models") +_LOGGER = logging.getLogger(__name__) router = APIRouter(tags=["models"]) client = httpx.AsyncClient(timeout=10.0) -@router.api_route("/models/{rest_of_path:path}", - methods=["GET", "POST", "PUT", "DELETE"], - response_class=StreamingResponse) -async def proxy(request: Request, rest_of_path: str, - current_user: User = Depends(auth.get_current_user)): + +@router.api_route( + "/models/{rest_of_path:path}", methods=["GET", "POST", "PUT", "DELETE"], response_class=StreamingResponse +) +async def proxy(request: Request, rest_of_path: str, current_user: User = Depends(auth.get_current_user)): url = urljoin(settings.extralit_url, rest_of_path) params = dict(request.query_params) - _LOGGER.info('PROXY %s %s', url, params) + _LOGGER.info("PROXY %s %s", url, params) - if 'workspace' not in params or not params['workspace']: + if "workspace" not in params or not params["workspace"]: raise BadRequestError("`workspace` is required in query parameters") if current_user: - params['username'] = current_user.username + params["username"] = current_user.username - if current_user.role != "owner" and not await current_user.is_member_of_workspace_name(params['workspace']): - raise UnauthorizedError(f"{current_user.username} is not authorized to access workspace {params['workspace']}") + if current_user.role != "owner" and not await current_user.is_member_of_workspace_name(params["workspace"]): + raise UnauthorizedError( + f"{current_user.username} is not authorized 
to access workspace {params['workspace']}" + ) if request.method == "GET": proxy_request = client.build_request("GET", url, params=params) @@ -66,12 +82,14 @@ async def stream_response(): return StreamingResponse(stream_response(), media_type="text/event-stream") + @router.on_event("startup") async def startup_event(): global client if client is None: client = httpx.AsyncClient(timeout=10.0) + @router.on_event("shutdown") async def shutdown(): await client.aclose() diff --git a/argilla-server/src/argilla_server/contexts/files.py b/argilla-server/src/argilla_server/contexts/files.py index fe7af4ad3..d22d17fcd 100644 --- a/argilla-server/src/argilla_server/contexts/files.py +++ b/argilla-server/src/argilla_server/contexts/files.py @@ -34,11 +34,10 @@ from argilla_server.api.schemas.v1.files import ListObjectsResponse, ObjectMetadata, FileObjectResponse from argilla_server.settings import settings -from argilla_server.api.schemas.v1.files import FileObjectResponse EXCLUDED_VERSIONING_PREFIXES = ["pdf"] -_LOGGER = logging.getLogger("argilla") +_LOGGER = logging.getLogger(__name__) class LocalFileStorage: diff --git a/argilla-server/src/argilla_server/contexts/hub.py b/argilla-server/src/argilla_server/contexts/hub.py index 74aec43ec..8f2a21f2b 100644 --- a/argilla-server/src/argilla_server/contexts/hub.py +++ b/argilla-server/src/argilla_server/contexts/hub.py @@ -412,7 +412,7 @@ def _push_extra_files_to_hub(self, repo_id: str, token: str) -> None: hf_api = HfApi(token=token) with TemporaryDirectory() as temporary_directory: - argilla_directory = os.path.join(temporary_directory, ".argilla") + argilla_directory = os.path.join(temporary_directory, ".extralit") os.makedirs(argilla_directory) self._create_version_file(argilla_directory) diff --git a/argilla-server/src/argilla_server/settings.py b/argilla-server/src/argilla_server/settings.py index 5156a3a05..ed61e3605 100644 --- a/argilla-server/src/argilla_server/settings.py +++ 
b/argilla-server/src/argilla_server/settings.py @@ -179,7 +179,7 @@ def set_enable_telemetry(cls, enable_telemetry: bool) -> bool: @field_validator("home_path", mode="before") @classmethod def set_home_path_default(cls, home_path: str): - return home_path or os.path.join(Path.home(), ".argilla") + return home_path or os.path.join(Path.home(), ".extralit") @field_validator("base_url") @classmethod diff --git a/argilla-v1/CHANGELOG.md b/argilla-v1/CHANGELOG.md index 13f25d9ea..062560fa9 100644 --- a/argilla-v1/CHANGELOG.md +++ b/argilla-v1/CHANGELOG.md @@ -14,13 +14,11 @@ These are the section headers that we use: * "Security" in case of vulnerabilities. --> -## [Extralit] [0.3.0](https://github.com/extralit/extralit/compare/v0.2.3...v0.3.0) - ## [Argilla] [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) > [!NOTE] -> As per the release of our 2.0 SDK, this changelog is deprecated and will only contain potential bug fixes for the 1.x SDK, but it will not contain any new features. For the latest features and changes, please refer to the [2.0 SDK changelog](../argilla/CHANGELOG.md). +> As per the release of our 2.0 SDK, this changelog is deprecated and will only contain potential bug fixes for the 1.x SDK, but it will not contain any new features. For the latest features and changes, please refer to the [2.0 SDK changelog](../extralit/CHANGELOG.md). 
## [Argilla] [1.29.0](https://github.com/argilla-io/argilla/compare/v1.28.0...v1.29.0) diff --git a/argilla-v1/environment_dev.yml b/argilla-v1/environment_dev.yml index ff9788af5..2929e5675 100644 --- a/argilla-v1/environment_dev.yml +++ b/argilla-v1/environment_dev.yml @@ -2,7 +2,6 @@ name: argilla channels: - conda-forge - - defaults dependencies: - python~=3.9.7 @@ -39,8 +38,8 @@ dependencies: # code formatting - pre-commit~=3.2.0 # v1,v2 transition, updated to v2 in run-python-tests.yml - - pydantic>=1.10.7 - - pandas>=1.0.0 + - pydantic>=1.10.7,<2 + - pandas>=1.0.0,<2 # extra test dependencies - cleanlab~=2.0.0 # With this version, tests are failing - datasets>1.17.0,!= 2.3.2 # TODO: push_to_hub fails up to 2.3.2, check patches when they come out eventually diff --git a/argilla-v1/pyproject.toml b/argilla-v1/pyproject.toml index 244bdb32b..4edc9a873 100644 --- a/argilla-v1/pyproject.toml +++ b/argilla-v1/pyproject.toml @@ -30,12 +30,16 @@ dependencies = [ "httpx >= 0.15,<= 0.26", "deprecated ~= 1.2.0", "packaging >= 20.0", + # pandas -> For data loading + "pandas >=1.0.0", # Aligned pydantic version with server fastAPI - "pydantic >= 1.10.7", + "pydantic >= 1.10.7,<2", # monitoring "wrapt >= 1.14,< 1.15", + # weaksupervision + "numpy < 1.27.0", # for progressbars - "tqdm >= 4.66.1, < 5.0.0", + "tqdm >= 4.27.0", # monitor background consumers "backoff", "monotonic", @@ -51,9 +55,7 @@ dependencies = [ "fastapi < 1.0.0", "pypandoc ~= 1.13", "beautifulsoup4 ~= 4.12.2", - "pandas ~= 2.2.2", "pandera[io] ~= 0.19.3", - "numpy ~= 1.26.4", "spacy ~= 3.7.2", "pyarrow == 14.*", "natsort ~= 8.4.0", @@ -61,8 +63,8 @@ dependencies = [ "dill ~= 0.3.8", "json-repair ~= 0.19.2", "fastparquet", - "tiktoken", - "pymupdf", + "tiktoken ~= 0.9.0", + "pymupdf==1.26.0", # for llama-index "llama-index ~= 0.10.68", "llama-index-core ~= 0.10.68", @@ -70,9 +72,6 @@ dependencies = [ "llama-index-llms-openai ~= 0.1.31", "llama-index-embeddings-openai ~=0.1.11", 
"llama-index-multi-modal-llms-openai", - # for weaviate vector db - "weaviate-client >= 4", - "llama-index-vector-stores-weaviate ~= 1.0.0", ] dynamic = ["version"] # This line is just to force the build diff --git a/argilla-v1/src/argilla_v1/cli/extraction/__main__.py b/argilla-v1/src/argilla_v1/cli/extraction/__main__.py index cfd02869a..526500860 100644 --- a/argilla-v1/src/argilla_v1/cli/extraction/__main__.py +++ b/argilla-v1/src/argilla_v1/cli/extraction/__main__.py @@ -1,4 +1,3 @@ - from typing import Optional from argilla_v1.cli.callback import init_callback @@ -9,12 +8,14 @@ _COMMANDS_REQUIRING_WORKSPACE = ["export"] _COMMANDS_REQUIRING_ENVFILE = ["export"] + def callback( ctx: typer.Context, workspace: str = typer.Option(None, help="Name of the workspace to which apply the command."), env_file: str = typer.Option(None, help="Path to .env file with environment variables containing S3 credentials."), ) -> None: from argilla_v1.client.singleton import active_client + init_callback() if ctx.invoked_subcommand not in _COMMANDS_REQUIRING_ENVFILE: @@ -24,7 +25,9 @@ def callback( return if workspace is None: - raise typer.BadParameter("The command requires a workspace name provided using '--workspace' option the {typer.style(ctx.invoked_subcommand, bold=True)} keyword") + raise typer.BadParameter( + "The command requires a workspace name provided using '--workspace' option the {typer.style(ctx.invoked_subcommand, bold=True)} keyword" + ) elif env_file is None: raise typer.BadParameter("The command requires a .env file path provided using '--env-file' option") @@ -49,12 +52,13 @@ def callback( success=False, ) raise typer.Exit(code=1) from e - + from dotenv import load_dotenv + if env_file is not None: load_dotenv(env_file) - from extralit.server.context.files import get_minio_client + from extralit_v1.server.context.files import get_minio_client minio_client = get_minio_client() ctx.obj = { diff --git a/argilla-v1/src/argilla_v1/cli/extraction/export.py 
b/argilla-v1/src/argilla_v1/cli/extraction/export.py index eedd4b662..80b9b75b0 100644 --- a/argilla-v1/src/argilla_v1/cli/extraction/export.py +++ b/argilla-v1/src/argilla_v1/cli/extraction/export.py @@ -1,11 +1,10 @@ - - from typing import Dict, Optional import typer from argilla_v1.client.enums import DatasetType -from extralit.server.context.files import get_minio_client +from extralit_v1.server.context.files import get_minio_client + def export_data( ctx: typer.Context, @@ -17,4 +16,3 @@ def export_data( ) -> None: print("export_data", ctx.obj, type_) print(get_minio_client()) - \ No newline at end of file diff --git a/argilla-v1/src/argilla_v1/cli/schemas/delete.py b/argilla-v1/src/argilla_v1/cli/schemas/delete.py index cdc3971b4..a99e03885 100644 --- a/argilla-v1/src/argilla_v1/cli/schemas/delete.py +++ b/argilla-v1/src/argilla_v1/cli/schemas/delete.py @@ -1,11 +1,24 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import os -from pathlib import Path -from typing import Dict, Optional, List -from extralit.constants import DEFAULT_SCHEMA_S3_PATH +from typing import Optional import typer from argilla_v1.client.workspaces import Workspace -from extralit.extraction.models import SchemaStructure +from extralit_v1.extraction.models.schema import DEFAULT_SCHEMA_S3_PATH + def delete_schema( ctx: typer.Context, @@ -26,8 +39,7 @@ def delete_schema( hidden=True, ), ) -> None: - from argilla_v1.cli.rich import echo_in_panel, get_argilla_themed_table - from rich.console import Console + from argilla_v1.cli.rich import echo_in_panel try: workspace: Workspace = ctx.obj["workspace"] @@ -42,7 +54,7 @@ def delete_schema( except Exception as e: echo_in_panel( - f"Unable to list schemas in workspace:\n{e}", + f"Unable to list schemas in workspace:\n{e}", title="Unexpected error", title_align="left", success=False, diff --git a/argilla-v1/src/argilla_v1/cli/schemas/list.py b/argilla-v1/src/argilla_v1/cli/schemas/list.py index 06e36418b..cfa351e02 100644 --- a/argilla-v1/src/argilla_v1/cli/schemas/list.py +++ b/argilla-v1/src/argilla_v1/cli/schemas/list.py @@ -1,10 +1,23 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from pathlib import Path -from typing import Dict, Optional, List -from extralit.constants import DEFAULT_SCHEMA_S3_PATH +from typing import Optional import typer from argilla_v1.client.workspaces import Workspace -from extralit.extraction.models import SchemaStructure +from extralit_v1.extraction.models.schema import DEFAULT_SCHEMA_S3_PATH def list_schemas( @@ -69,13 +82,13 @@ def list_schemas( if csv_path: df = console_table_to_pandas_df(table) df.to_csv(csv_path, index=False) - + else: console.print(table) except Exception as e: echo_in_panel( - f"Unable to list schemas in workspace:\n{e}", + f"Unable to list schemas in workspace:\n{e}", title="Unexpected error", title_align="left", success=False, diff --git a/argilla-v1/src/argilla_v1/cli/schemas/upload.py b/argilla-v1/src/argilla_v1/cli/schemas/upload.py index 087589726..cde0768fd 100644 --- a/argilla-v1/src/argilla_v1/cli/schemas/upload.py +++ b/argilla-v1/src/argilla_v1/cli/schemas/upload.py @@ -3,7 +3,8 @@ import typer from argilla_v1.client.workspaces import Workspace -from extralit.extraction.models import SchemaStructure +from extralit_v1.extraction.models import SchemaStructure + def upload_schemas( ctx: typer.Context, @@ -37,7 +38,7 @@ def upload_schemas( if not update_schemas.schemas: raise FileNotFoundError(f"No schemas found in directory '{path}'.") - + uploaded_files = workspace.update_schemas(update_schemas, check_existing=not overwrite) if uploaded_files.objects: @@ -66,7 +67,7 @@ def upload_schemas( except Exception as e: echo_in_panel( - f"Unable to update schemas in workspace:\n{e}", + f"Unable to update schemas in workspace:\n{e}", title="Unexpected error", title_align="left", success=False, diff --git a/argilla-v1/src/argilla_v1/client/feedback/schemas/documents.py b/argilla-v1/src/argilla_v1/client/feedback/schemas/documents.py index ba644bffc..0c8e1b1d4 100644 --- a/argilla-v1/src/argilla_v1/client/feedback/schemas/documents.py +++ 
b/argilla-v1/src/argilla_v1/client/feedback/schemas/documents.py @@ -5,9 +5,10 @@ from typing import Any, Dict, List, Literal, Optional, Union import uuid -from argilla.pydantic_v1 import BaseModel, Field, Extra +from argilla_v1.pydantic_v1 import BaseModel, Field, Extra from uuid import UUID + class Document(BaseModel, ABC): """Schema for the `Document` model. @@ -20,7 +21,10 @@ class Document(BaseModel, ABC): workspace_id: The workspace ID of the document. Required. """ - id: Union[UUID, str] = Field(default_factory=uuid.uuid4(), description="The ID of the document, which gets assigned randomly if not provided.") + id: Union[UUID, str] = Field( + default_factory=uuid.uuid4(), + description="The ID of the document, which gets assigned randomly if not provided.", + ) file_name: str = Field(...) reference: Optional[str] = None doi: Optional[str] = None @@ -32,12 +36,19 @@ class Document(BaseModel, ABC): class Config: validate_assignment = True extra = Extra.forbid - json_encoders = { - UUID: str - } + json_encoders = {UUID: str} @classmethod - def from_file(cls, file_path: str, *, reference: str, id: Optional[str] = None, pmid: Optional[str] = None, doi: Optional[str] = None, workspace_id: Optional[UUID] = None) -> "Document": + def from_file( + cls, + file_path: str, + *, + reference: str, + id: Optional[str] = None, + pmid: Optional[str] = None, + doi: Optional[str] = None, + workspace_id: Optional[UUID] = None, + ) -> "Document": url = None if os.path.exists(file_path): @@ -48,7 +59,7 @@ def from_file(cls, file_path: str, *, reference: str, id: Optional[str] = None, url = file_path parsed_url = urlparse(file_path) path = parsed_url.path - file_name = unquote(path).split('/')[-1] + file_name = unquote(path).split("/")[-1] else: raise ValueError(f"File path {file_path} does not exist") @@ -58,7 +69,7 @@ def from_file(cls, file_path: str, *, reference: str, id: Optional[str] = None, file_name=file_name if isinstance(file_name, str) else None, url=url if 
isinstance(url, str) else None, id=id or uuid.uuid4(), - pmid=str(pmid) if isinstance(pmid, int) or isinstance(pmid, str) and len(pmid)>3 else None, + pmid=str(pmid) if isinstance(pmid, int) or isinstance(pmid, str) and len(pmid) > 3 else None, doi=doi if isinstance(doi, str) else None, workspace_id=workspace_id, ) @@ -81,4 +92,4 @@ def to_server_payload(self) -> Dict[str, Any]: return json def __repr__(self) -> str: - return f"{self.__class__.__name__}(id={self.id!r}, file_name={self.file_name!r}, pmid={self.pmid!r}, doi={self.doi!r}, workspace_id={self.workspace_id!r})" \ No newline at end of file + return f"{self.__class__.__name__}(id={self.id!r}, file_name={self.file_name!r}, pmid={self.pmid!r}, doi={self.doi!r}, workspace_id={self.workspace_id!r})" diff --git a/argilla-v1/src/argilla_v1/client/feedback/schemas/records.py b/argilla-v1/src/argilla_v1/client/feedback/schemas/records.py index 969fdf974..bcb46561d 100644 --- a/argilla-v1/src/argilla_v1/client/feedback/schemas/records.py +++ b/argilla-v1/src/argilla_v1/client/feedback/schemas/records.py @@ -18,7 +18,8 @@ from argilla_v1.client.feedback.schemas.enums import RecordSortField, SortOrder -import argilla as rg +import argilla_v1 as rg + # Support backward compatibility for import of RankingValueSchema from records module from argilla_v1.client.feedback.schemas.response_values import RankingValueSchema # noqa from argilla_v1.client.feedback.schemas.responses import ResponseSchema, ValueSchema # noqa @@ -104,7 +105,7 @@ def __repr_args__(self): repr_args[i] = (name, [v.question_name for v in value]) return repr_args - + @validator("suggestions", always=True) def normalize_suggestions(cls, values: Any) -> Tuple: if not isinstance(values, tuple): @@ -149,7 +150,8 @@ def to_server_payload(self, question_name_to_id: Optional[Dict[str, UUID]] = Non payload["responses"] = [response.to_server_payload() for response in self.responses] if question_name_to_id: payload["suggestions"] = [ - 
suggestion.to_server_payload(question_name_to_id) for suggestion in self.suggestions \ + suggestion.to_server_payload(question_name_to_id) + for suggestion in self.suggestions if suggestion.question_name in question_name_to_id ] diff --git a/argilla-v1/src/argilla_v1/client/workspaces.py b/argilla-v1/src/argilla_v1/client/workspaces.py index 05dc68fc1..a55cbe012 100644 --- a/argilla-v1/src/argilla_v1/client/workspaces.py +++ b/argilla-v1/src/argilla_v1/client/workspaces.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Iterator, List, Optional, Union from uuid import UUID -from extralit.constants import DEFAULT_SCHEMA_S3_PATH from argilla_v1.client.sdk.commons.errors import ( AlreadyExistsApiError, @@ -37,7 +36,7 @@ from argilla_v1.client.sdk.v1.files.models import FileObjectResponse, ListObjectsResponse import pandera as pa -from extralit.extraction.models import SchemaStructure +from extralit_v1.extraction.models.schema import SchemaStructure, DEFAULT_SCHEMA_S3_PATH if TYPE_CHECKING: import httpx @@ -46,6 +45,7 @@ _LOGGER = logging.getLogger(__name__) + class Workspace: """The `Workspace` class is used to manage workspaces in Argilla. It provides methods to create new workspaces, add users to them, list the linked users, @@ -267,9 +267,10 @@ def delete(self) -> None: ) from e except BaseClientError as e: raise RuntimeError(f"Error while deleting workspace with id {self.id!r}.") from e - - def get_schemas(self, prefix: str = DEFAULT_SCHEMA_S3_PATH, exclude: Optional[List[str]]=None) -> "SchemaStructure": - + + def get_schemas( + self, prefix: str = DEFAULT_SCHEMA_S3_PATH, exclude: Optional[List[str]] = None + ) -> "SchemaStructure": """ Get the schemas from the workspace. 
@@ -280,7 +281,8 @@ def get_schemas(self, prefix: str = DEFAULT_SCHEMA_S3_PATH, exclude: Optional[Li schemas = {} try: workspace_files: ListObjectsResponse = workspaces_api_v1.list_workspace_files( - self._client, workspace_name=self.name, path=prefix).parsed + self._client, workspace_name=self.name, path=prefix + ).parsed for file_metadata in workspace_files.objects: _, file_ext = os.path.splitext(file_metadata.object_name) @@ -289,9 +291,13 @@ def get_schemas(self, prefix: str = DEFAULT_SCHEMA_S3_PATH, exclude: Optional[Li try: file_content: str = workspaces_api_v1.get_workspace_file( - self._client, workspace_name=self.name, path=file_metadata.object_name, version_id=file_metadata.version_id).parsed + self._client, + workspace_name=self.name, + path=file_metadata.object_name, + version_id=file_metadata.version_id, + ).parsed schema = pa.io.from_json(file_content) - + if schema.name in schemas or (exclude and schema.name in exclude): continue @@ -302,12 +308,10 @@ def get_schemas(self, prefix: str = DEFAULT_SCHEMA_S3_PATH, exclude: Optional[Li schemas[schema.name] = schema except Exception as e: - raise RuntimeError( - f"Error getting schemas from workspace with name=`{self.name}` due to: `{e}`." 
- ) from e + raise RuntimeError(f"Error getting schemas from workspace with name=`{self.name}` due to: `{e}`.") from e return SchemaStructure(schemas=list(schemas.values())) - + @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) def list_files(self, path: str, recursive=True, include_version=True) -> ListObjectsResponse: """ @@ -321,14 +325,13 @@ def list_files(self, path: str, recursive=True, include_version=True) -> ListObj """ try: return workspaces_api_v1.list_workspace_files( - self._client, workspace_name=self.name, path=path, recursive=recursive, include_version=include_version).parsed + self._client, workspace_name=self.name, path=path, recursive=recursive, include_version=include_version + ).parsed except Exception as e: - raise RuntimeError( - f"Error listing files in workspace with name=`{self.name}` due to: `{e}`." - ) from e - + raise RuntimeError(f"Error listing files in workspace with name=`{self.name}` due to: `{e}`.") from e + @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) - def delete_file(self, path: str, version_id: Optional[str]=None) -> None: + def delete_file(self, path: str, version_id: Optional[str] = None) -> None: """ Deletes schemas from the workspace. 
@@ -337,13 +340,11 @@ def delete_file(self, path: str, version_id: Optional[str]=None) -> None: """ try: workspaces_api_v1.delete_workspace_file( - self._client, workspace_name=self.name, path=path, version_id=version_id) + self._client, workspace_name=self.name, path=path, version_id=version_id + ) except Exception as e: - raise RuntimeError( - f"Error deleting file `{path}` from workspace" - ) from e + raise RuntimeError(f"Error deleting file `{path}` from workspace") from e - @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) def add_schema(self, schema: Optional["pa.DataFrameSchema"], prefix: str = DEFAULT_SCHEMA_S3_PATH): """ @@ -357,23 +358,28 @@ def add_schema(self, schema: Optional["pa.DataFrameSchema"], prefix: str = DEFAU """ assert schema is not None, "Schema cannot be None." object_path = os.path.join(prefix, schema.name) - file_path = Path('/tmp') / schema.name - with open(file_path, 'w', encoding='utf-8') as f: + file_path = Path("/tmp") / schema.name + with open(file_path, "w", encoding="utf-8") as f: f.write(schema.to_json()) - + try: - if workspaces_api_v1.exist_workspace_file(self._client, workspace_name=self.name, path=object_path, file_path=file_path): + if workspaces_api_v1.exist_workspace_file( + self._client, workspace_name=self.name, path=object_path, file_path=file_path + ): raise ValueError(f"Schema with name=`{schema.name}` already exists in workspace with id=`{self.id}`.") - - workspaces_api_v1.put_workspace_file(self._client, workspace_name=self.name, path=object_path, file_path=file_path) + + workspaces_api_v1.put_workspace_file( + self._client, workspace_name=self.name, path=object_path, file_path=file_path + ) except Exception as e: raise RuntimeError( f"Error adding schema with name=`{schema.name}` to workspace with id=`{self.id}`." 
) from e - @allowed_for_roles(roles=[UserRole.owner, UserRole.admin]) - def update_schemas(self, schemas: "SchemaStructure", check_existing=True, prefix: str = DEFAULT_SCHEMA_S3_PATH) -> ListObjectsResponse: + def update_schemas( + self, schemas: "SchemaStructure", check_existing=True, prefix: str = DEFAULT_SCHEMA_S3_PATH + ) -> ListObjectsResponse: """ Updates existing schemas in the workspace. @@ -385,23 +391,27 @@ def update_schemas(self, schemas: "SchemaStructure", check_existing=True, prefix output_metadata = [] for schema in schemas.schemas: object_path = os.path.join(prefix, schema.name) - file_path = Path('/tmp') / schema.name - with open(file_path, 'w', encoding='utf-8') as f: + file_path = Path("/tmp") / schema.name + with open(file_path, "w", encoding="utf-8") as f: f.write(schema.to_json()) - + try: - if check_existing and workspaces_api_v1.exist_workspace_file(self._client, workspace_name=self.name, path=object_path, file_path=file_path): - _LOGGER.info(f"Skipping schema name='{schema.name}' update since it's unmodified in workspace with name='{self.name}'.") + if check_existing and workspaces_api_v1.exist_workspace_file( + self._client, workspace_name=self.name, path=object_path, file_path=file_path + ): + _LOGGER.info( + f"Skipping schema name='{schema.name}' update since it's unmodified in workspace with name='{self.name}'." + ) continue - response = workspaces_api_v1.put_workspace_file(self._client, workspace_name=self.name, path=object_path, file_path=file_path) + response = workspaces_api_v1.put_workspace_file( + self._client, workspace_name=self.name, path=object_path, file_path=file_path + ) output_metadata.append(response.parsed) except Exception as e: - raise RuntimeError( - f"Error adding schema '{schema.name}' to workspace due to `{e}`." 
- ) from e - + raise RuntimeError(f"Error adding schema '{schema.name}' to workspace due to `{e}`.") from e + return ListObjectsResponse(objects=output_metadata) - + @staticmethod def __active_client() -> "httpx.Client": """Returns the active Argilla `httpx.Client` instance.""" diff --git a/argilla-v1/src/extralit_v1/__init__.py b/argilla-v1/src/extralit_v1/__init__.py new file mode 100644 index 000000000..b5fdc7530 --- /dev/null +++ b/argilla-v1/src/extralit_v1/__init__.py @@ -0,0 +1 @@ +__version__ = "0.2.2" diff --git a/argilla/tests/extralit/metrics/__init__.py b/argilla-v1/src/extralit_v1/convert/__init__.py similarity index 100% rename from argilla/tests/extralit/metrics/__init__.py rename to argilla-v1/src/extralit_v1/convert/__init__.py diff --git a/argilla-v1/src/extralit_v1/convert/html_table.py b/argilla-v1/src/extralit_v1/convert/html_table.py new file mode 100644 index 000000000..0395bcbd8 --- /dev/null +++ b/argilla-v1/src/extralit_v1/convert/html_table.py @@ -0,0 +1,217 @@ +import html +import io +import logging +import re +from collections import defaultdict +from typing import Union + +import pandas as pd +from bs4 import BeautifulSoup + +from extralit_v1.convert.text import remove_markdown_from_string + + +def html_table_to_json(s: str) -> str: + if not isinstance(s, io.StringIO): + s = io.StringIO(s) + + df = html_to_df(s, convert_spanning_rows=False) + df.columns = df.columns.astype(str).str.replace(".", "") + + df_json = df.to_json(orient="table", index=bool(df.index.name) or len(df.index.names) > 1) + return df_json + + +def html_to_df( + s: Union[io.StringIO, str], + flatten_columns=True, + convert_spanning_rows=False, + remove_markdown=False, + rename_duplicate_columns=True, +) -> pd.DataFrame: + if not isinstance(s, io.StringIO): + s = io.StringIO(s.replace("

", "
")) + + df = pd.read_html(s)[0] + df_str_cols = df.select_dtypes(include=["O", "string"]) + df.loc[:, df_str_cols.columns] = df_str_cols.map( + lambda x: html.unescape(x.strip()) if isinstance(x, str) else x, na_action="ignore" + ) + + if isinstance(df.columns, pd.MultiIndex): + new_columns = [tuple(re.sub(r"\.\d+$", "", level) for level in levels) for levels in df.columns] + df.columns = pd.MultiIndex.from_tuples(new_columns) + + df.columns = df.columns.map( + lambda column: re.sub(r"Unnamed: \d+", "Variable", column) if isinstance(column, str) else column + ) + + if flatten_columns: + df.columns = flatten_multilevel_columns(df.columns) + + if convert_spanning_rows: + df = convert_spanning_rows_to_group_column(df, new_column="Group") + + if rename_duplicate_columns: + df.columns = rename_to_unique_columns(df.columns) + + if remove_markdown: + df = remove_markdown_from_data_frame(df) + + return df + + +def convert_spanning_rows_to_group_column(df: pd.DataFrame, new_column="Group") -> pd.DataFrame: + # These rows have the same value across all columns + mask = df.iloc[:, :].apply(lambda row: row.nunique() == 1, axis=1) + spanning_rows = df[mask] + mask_consecutive = mask.groupby((~mask).cumsum()).transform("size") * mask > 1 + + if spanning_rows.empty: + return df + elif mask[0] is False or any(mask_consecutive): + logging.info( + f"Skipping pivot of spanning rows, since the first row is not spanning or there are multiple consecutive spans." 
+ ) + return df + + if new_column in df.columns: + new_column = "Subgroup" + + # Determine group names and the range of rows they span + last_spanning_row_index = 0 + for index, row in spanning_rows.iterrows(): + if last_spanning_row_index is not None: + # Assign group name to the range of rows between the last and current spanning rows + group_name = df.iloc[last_spanning_row_index, 0] + df.loc[last_spanning_row_index + 1 : index - 1, new_column] = group_name + last_spanning_row_index = index + + # Add the last group if there's any data after the last spanning row + if last_spanning_row_index is not None and last_spanning_row_index + 1 < len(df): + group_name = df.iloc[last_spanning_row_index, 0] + df.loc[last_spanning_row_index + 1 :, new_column] = group_name + + # Remove the spanning rows + df = df[~mask].reset_index(drop=True) + + # Reorder columns to have new_column as the first column + df.insert(0, new_column, df.pop(new_column)) + + return df + + +def flatten_multilevel_columns(columns: Union[pd.MultiIndex, pd.Index], sep=" - ") -> pd.Index: + if isinstance(columns, pd.MultiIndex): + new_columns = columns.map( + lambda levels: sep.join( + list(dict.fromkeys(str(level) for level in levels if level)) + if isinstance(levels, tuple) + else str(levels) + ) + ) + else: + new_columns = columns + + return new_columns + + +def rename_to_unique_columns(columns: Union[pd.Index, pd.MultiIndex]) -> Union[pd.Index, pd.MultiIndex]: + column_counts = defaultdict(int) + + if isinstance(columns, pd.MultiIndex): + new_columns = [] + for levels in columns: + if levels in column_counts: + column_counts[levels] += 1 + if column_counts[levels] > 1: + levels[-1] = f"{levels[-1]}.{column_counts[levels]}" + else: + column_counts[levels] = 0 + + new_columns.append(tuple(levels)) + + return pd.MultiIndex.from_tuples(new_columns) + + else: + new_columns = [] + column_counts = {} + for col in columns: + if col in column_counts: + column_counts[col] += 1 + new_column = 
f"{col}.{column_counts[col]}" + else: + column_counts[col] = 0 + new_column = col + new_columns.append(new_column) + + return pd.Index(new_columns) + + +def remove_markdown_from_data_frame(df: pd.DataFrame) -> pd.DataFrame: + df = df.map(remove_markdown_from_string) + + if isinstance(df.columns, pd.MultiIndex): + df.columns = pd.MultiIndex.from_tuples( + [tuple(remove_markdown_from_string(label) for label in col) for col in df.columns], names=df.columns.names + ) + else: + df.columns = [remove_markdown_from_string(col) for col in df.columns] + + return df + + +def remove_html_styles(html_str): + style_pattern = r'\s*style="[^"]*"' + html_str = re.sub(style_pattern, "", html_str) + + class_pattern = r'\s*class="[^"]*"' + html_str = re.sub(class_pattern, "", html_str) + + tag_pattern = r"<\/?span[^>]*>|<\/?em[^>]*>" + html_str = re.sub(tag_pattern, "", html_str) + + return html_str + + +def fix_llmsherpa_html_table(html): + soup = BeautifulSoup(html, "html.parser") + if not soup.table: + return html + + # Ensure exists. 
If not, wrap the first in + if not soup.thead: + for th in soup.find_all("th"): + if not th.find_parent("thead"): + thead = soup.new_tag("thead") + th.wrap(thead) + + # Convert to in + if soup.thead: + for td in soup.thead.find_all("td"): + td.name = "th" + + # Ensure exists for the remaining rows + if not soup.tbody and soup.thead: + tbody = soup.new_tag("tbody") + # Exclude rows already inside + for tr in soup.table.find_all("tr", recursive=False): + if tr not in soup.thead: + tbody.append(tr) + soup.table.append(tbody) + + return str(soup) + + +def llmsherpa_html_to_df(html: io.StringIO): + html = fix_llmsherpa_html_table(html) + df = pd.read_html(io.StringIO(html) if not isinstance(html, io.StringIO) else html)[0] + # df = df.dropna(axis='columns', how='all') + + columns = df.columns.to_frame() + columns = columns.fillna("Unnamed") + df.columns = pd.MultiIndex.from_frame(columns) + + df.columns = rename_to_unique_columns(df.columns) + + return df diff --git a/argilla-v1/src/extralit_v1/convert/json_table.py b/argilla-v1/src/extralit_v1/convert/json_table.py new file mode 100644 index 000000000..d24a399da --- /dev/null +++ b/argilla-v1/src/extralit_v1/convert/json_table.py @@ -0,0 +1,232 @@ +import io +import json +import logging +import re +from typing import List, Optional, Union, Dict, Literal, Any + +import pandas as pd +import pandera as pa + +from extralit_v1.schema.checks.utils import make_same_length_arguments +from extralit_v1.server.models.extraction import ExtractionResponse + +_LOGGER = logging.getLogger(__name__) + + +def drop_single_value_index_names(df: pd.DataFrame) -> pd.Index: + names_to_drop = [] + for name in df.index.names: + num_uniques = df.index.get_level_values(name).nunique() + + if num_uniques <= 1: + names_to_drop.append(name) + + if names_to_drop and len(names_to_drop) == len(df.index.names): + names_to_drop = names_to_drop[1:] + + return pd.Index(names_to_drop) + + +def preprocess(df: pd.DataFrame, required_columns: List[str] = [], 
drop_columns: Optional[List[str]] = None): + """ + Preprocess a DataFrame before utils. + """ + assert isinstance(required_columns, list), "required_columns must be a list" + + if drop_columns: + drop_columns = pd.Index(drop_columns).difference(required_columns) + drop_index_levels = set(df.index.names or [df.index.name]) & set(drop_columns) + if drop_index_levels: + df.index = df.index.droplevel(list(drop_index_levels)) + df = df.drop(columns=drop_columns, errors="ignore") + + # Drop columns with same name as any index name or multiindex names + df = df.drop(columns=df.columns.intersection(df.index.names), errors="ignore") + + # Replace values in string columns + df_str_cols = df.select_dtypes(include=["O", "string"]) + df.loc[:, df_str_cols.columns] = df_str_cols.replace({None: "NA", "": "NA"}) + + # Drop columns with all "NA" + # all_na_columns = df.columns.difference(required_columns) + # all_na_columns = all_na_columns[(df[all_na_columns].isna() | (df[all_na_columns] == 'NA')).all(axis=0)] + # df = df.drop(columns=all_na_columns, errors='ignore') + + return df + + +def standardize_values(df: pd.DataFrame, schema: pa.DataFrameSchema) -> pd.DataFrame: + # Capitalize string values if the schema has `isin` checks with capital letters + for column_name, columnSchema in schema.columns.items(): + if column_name not in df.columns: + continue + isin_checks = [check for check in columnSchema.checks if isinstance(check, pa.Check) and hasattr(check, "isin")] + if not isin_checks or "allowed_values" not in isin_checks[0].statistics: + continue + + if all([value.istitle() for value in isin_checks[0].statistics["allowed_values"]]): + df[column_name] = df[column_name].apply(lambda x: x.capitalize() if pd.notna(x) and x != "NA" else x) + + return df + + +def get_required_columns(schema: pa.DataFrameSchema) -> List[str]: + required_columns = [name for name, column in schema.columns.items() if not column.nullable] + index_columns = [index.name for index in schema.index.indexes] 
if hasattr(schema.index, "indexes") else [] + required_columns = required_columns + index_columns + return required_columns + + +def df_to_json( + df: pd.DataFrame, + schema: pa.DataFrameSchema, + drop_columns=None, + transpose=False, + metadata: Optional[Dict[str, Any]] = None, + **kwargs, +) -> str: + """ + Convert a DataFrame to a JSON string. If a schema is provided, the DataFrame will be validated against the schema. + + Args: + df: pd.DataFrame + schema: DataFrameSchema + drop_columns: List of columns to drop + transpose: Transpose the DataFrame + metadata: Additional metadata to include in the JSON + + Returns: + JSON string + """ + assert isinstance(schema, pa.DataFrameSchema), "schema must be a DataFrameSchema" + required_columns = get_required_columns(schema) + + df = preprocess(df, required_columns=required_columns, drop_columns=drop_columns) + + try: + df = standardize_values(df, schema) + except Exception as e: + print("Failed to standardize values:", e) + + if transpose: + df = df.T + df.index.name = df.index.name or "reference" + + try: + df_json = json.loads(df.to_json(orient="table", index=bool(df.index.name) or len(df.index.names) > 1)) + except Exception as e: + print("Failed to convert DataFrame to JSON:", e) + print(df) + raise e + + if schema is not None: + df_json["schema"]["schemaName"] = schema.name + + if metadata: + df_json = {**metadata, **df_json} + + return json.dumps(df_json) + + +def schema_to_json(dataframe_schema: pa.DataFrameSchema) -> Dict[str, Any]: + schema_specs = json.loads(dataframe_schema.to_json()) + + for check in schema_specs["checks"] or []: + if check not in ["check_less_than", "check_greater_than", "check_between"]: + continue + schema_specs["checks"][check] = { + k: v + for k, v in zip( + schema_specs["checks"][check].keys(), make_same_length_arguments(**schema_specs["checks"][check]) + ) + } + + for index in schema_specs["index"]: + index["required"] = True + + return schema_specs + + +def json_to_df( + input: 
Union[str, List[Dict[str, Any]], Dict[str, Dict[str, Any]], ExtractionResponse], + schema: Optional[pa.DataFrameSchema] = None, +) -> pd.DataFrame: + """ + Convert a JSON string to a DataFrame. If a schema is provided, the DataFrame will be validated against the schema. + + Args: + input: JSON string or a list of dictionaries + schema: DataFrameSchema + index_level_rename: Dictionary to rename index levels + + Returns: + pd.DataFrame + """ + if not input: + return pd.DataFrame() + + if schema is not None: + index_cols = [name for name in (schema.index.names if schema.index else []) if name] + required_columns = get_required_columns(schema) + else: + index_cols = [] + required_columns = [] + + try: + if isinstance(input, str): + df = parse_json_to_df(input) + elif isinstance(input, list): + df = pd.DataFrame(input) + elif isinstance(input, dict): + df = pd.DataFrame.from_dict(input, orient="index") + elif isinstance(input, ExtractionResponse): + df = pd.read_json(io.StringIO(json.dumps(input.dict())), orient="table") + else: + raise ValueError(f"Invalid input type: {type(input)}") + + df = df.map(lambda x: x.strip() if isinstance(x, str) else x) + + if schema is not None: + df = convert_to_schema_dtypes(df, schema) + + if index_cols and df.columns.intersection(index_cols).size: + df = df.set_index(index_cols) + except Exception as e: + _LOGGER.error(f"Failed to load DataFrame from JSON: {e}") + raise e + + return df + + +def parse_json_to_df(input: str) -> pd.DataFrame: + try: + df = pd.read_json(io.StringIO(input) if isinstance(input, str) else input, orient="table", convert_dates=False) + except: + input = re.sub(r'"type":"\w+"', '"type":"string"', input) + input = re.sub(r',\s*"extDtype":"[^"]+"', "", input) + df = pd.read_json(io.StringIO(input) if isinstance(input, str) else input, orient="table", convert_dates=False) + return df + + +def convert_to_schema_dtypes( + df: pd.DataFrame, schema: pa.DataFrameSchema, errors: Literal["raise", "ignore"] = "ignore" +) 
-> pd.DataFrame: + dtype_map = {"int64": "Int64", "int": "Int64", "int32": "Int64"} + + data_types = { + col: dtype_map.get(str(datatype), str(datatype)) for col, datatype in schema.dtypes.items() if col in df.columns + } + + df = df.astype(data_types, errors=errors) + return df + + +def is_json_table(json_string: str) -> bool: + if not isinstance(json_string, str) or not json_string.startswith("{") or not json_string.endswith("}"): + return False + + try: + json.dumps(json_string) + return True + except: + return False diff --git a/argilla-v1/src/extralit_v1/convert/markdown.py b/argilla-v1/src/extralit_v1/convert/markdown.py new file mode 100644 index 000000000..b95bf454f --- /dev/null +++ b/argilla-v1/src/extralit_v1/convert/markdown.py @@ -0,0 +1,15 @@ +import pandas as pd + + +def read_markdown_table_to_df(markdown_table: str) -> pd.DataFrame: + # Split the table into lines + lines = markdown_table.strip().split("\n") + + # Split each line into columns assuming | as the separator + rows = [line.split("|")[1:-1] for line in lines if line.startswith("|")] + + # Clean up any whitespace and markdown formatting from the cells + header = [cell.strip().strip("*_`") for cell in rows.pop(0)] # Assume first line is the header + data = [[cell.strip().strip("*_`") for cell in row] for row in rows if not all("-" in cell for cell in row)] + + return pd.DataFrame(data, columns=header) diff --git a/argilla-v1/src/extralit_v1/convert/pdf.py b/argilla-v1/src/extralit_v1/convert/pdf.py new file mode 100644 index 000000000..4cc612b1b --- /dev/null +++ b/argilla-v1/src/extralit_v1/convert/pdf.py @@ -0,0 +1,58 @@ +import os +from pathlib import Path +from typing import Optional + +from PIL import Image + +from extralit_v1.preprocessing.segment import Coordinates + + +def extract_image( + pdf_page: Image, coordinates: Coordinates, title: str, output_dir: Path, pad=20, redo=False +) -> Optional[str]: + if not coordinates or not coordinates.points: + return None + + # Create output 
directory if it doesn't exist + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + # Define output path + if not title.endswith(".png"): + title += ".png" + image_path = os.path.join(output_dir, title) + if os.path.exists(image_path) and not redo: + return image_path + + # Get page dimensions + page_width, page_height = pdf_page.size + + # Normalize coordinates + if coordinates.layout_width: + x_coords = [coordinates.points[i][0] * page_width / coordinates.layout_width for i in range(4)] + else: + x_coords = [coordinates.points[i][0] for i in range(4)] + + if coordinates.layout_height: + y_coords = [coordinates.points[i][1] * page_height / coordinates.layout_height for i in range(4)] + else: + y_coords = [coordinates.points[i][1] for i in range(4)] + + x1, x2 = min(x_coords), max(x_coords) + y1, y2 = min(y_coords), max(y_coords) + + # Add padding to the coordinates + x1 = max(0, x1 - pad) + y1 = max(0, y1 - pad) + x2 = min(page_width, x2 + pad) + y2 = min(page_height, y2 + pad) + + padded_coords = (x1, y1, x2, y2) + + # Crop the image + cropped_image = pdf_page.crop(padded_coords) + + # Save the image if it doesn't exist + cropped_image.save(image_path) + + return image_path diff --git a/argilla-v1/src/extralit_v1/convert/text.py b/argilla-v1/src/extralit_v1/convert/text.py new file mode 100644 index 000000000..966e738d8 --- /dev/null +++ b/argilla-v1/src/extralit_v1/convert/text.py @@ -0,0 +1,60 @@ +import logging +import re +from typing import List + +from rapidfuzz import fuzz + + +def remove_longest_repeated_subsequence(s: str, min_substring_len=1, min_repeats=5, verbose=True) -> str: + # Regex pattern to find all subsequences that repeat more than 5 times + pattern = r"(.+?)(?=\1{" + str(min_repeats) + ",})" + + # Find all matching subsequences + matches = re.findall(pattern, s, re.DOTALL) + + if not matches: + return s # Return the original string if no matches + + for subseq in set(matches): # Use set to remove 
duplicates + repeat_count = s.count(subseq) + + if len(subseq) >= min_substring_len and repeat_count > min_repeats: + if verbose: + logging.info(f"Removing repeating consecutive subsequence '{subseq}' repeated {repeat_count} times") + # Replace the repeating subsequence with an empty string + s = s.replace(subseq, "") + + return s + + +def find_longest_superstrings(strs: List[str], similarity_threshold: float = 90.0) -> List[str]: + superstrings = [] + + for string in sorted(strs, key=len, reverse=True): + # Check if the current string is a fuzzy substring of any existing superstring + is_substring = False + for superstring in superstrings: + if fuzz.partial_ratio(string, superstring) > similarity_threshold: + is_substring = True + break + + # If it's not a fuzzy substring of any superstring, check if it should absorb any superstrings + if not is_substring: + new_superstrings = [string] + for superstring in superstrings: + if fuzz.partial_ratio(superstring, string) <= similarity_threshold: + new_superstrings.append(superstring) + superstrings = new_superstrings + + return superstrings + + +def remove_markdown_from_string(s: str) -> str: + # Regular expression to identify markdown syntax + # Matches **bold**, __underline__, *italic*, and ~~strikethrough~~ + markdown_pattern = r"\*\*(.*?)\*\*|__(.*?)__|\*(.*?)\*|~~(.*?)~~|_(.*?)_" + return ( + re.sub(markdown_pattern, lambda m: m.group(1) or m.group(2) or m.group(3) or m.group(4) or m.group(5), s) + if isinstance(s, str) + else s + ) diff --git a/argilla/src/extralit/extraction/__init__.py b/argilla-v1/src/extralit_v1/extraction/__init__.py similarity index 100% rename from argilla/src/extralit/extraction/__init__.py rename to argilla-v1/src/extralit_v1/extraction/__init__.py diff --git a/argilla-v1/src/extralit_v1/extraction/chunking.py b/argilla-v1/src/extralit_v1/extraction/chunking.py new file mode 100644 index 000000000..7d0aaca52 --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/chunking.py @@ -0,0 
+1,151 @@ +import re +from os.path import join, exists +from typing import Optional, List, Tuple, Dict, Any, Iterable + +import argilla_v1 as rg +from extralit_v1.storage.files import FileHandler, StorageType +import pandas as pd +from llama_index.core.schema import Document + +from extralit_v1.convert.html_table import html_to_df +from extralit_v1.pipeline.ingest.segment import get_paper_tables +from extralit_v1.preprocessing.document import create_or_load_nougat_segments +from extralit_v1.preprocessing.segment import Segments + +INCLUDE_METADATA_KEYS = {"header": True, "footer": True, "level": True, "page_number": True, "type": True} +EXCLUDE_LLM_METADATA_KEYS = ["type", "page_number", "reference", "level"] + + +def create_nodes( + paper: pd.Series, + preprocessing_path="data/preprocessing/nougat/", + preprocessing_dataset: Optional[rg.FeedbackDataset] = None, + response_status=["submitted"], + exclude_llm_metadata_keys=EXCLUDE_LLM_METADATA_KEYS, + storage_type: StorageType = StorageType.FILE, + bucket_name: Optional[str] = None, + **nougat_kwargs, +) -> Tuple[List[Document], List[Document]]: + """ + Create or load the documents from the paper segments. + + Args: + paper: pd.Series, required + A paper from the dataset. + preprocessing_dataset: rg.FeedbackDataset, default=None + Manually annotated preprocessing dataset. If given, the TableSegments will be loaded from this dataset. + response_status: List[str], default=['submitted'] + The response status of the records to consider. + preprocessing_path: str, default='data/preprocessing/nougat/' + Path to the preprocessed data. + ignore_metadata: set, default={'text', 'type', 'level', 'children', 'coordinates', 'source', 'html', 'original', + 'probability', 'image'} + Metadata to exclude from the documents. + nougat_kwargs: dict + Additional arguments for the NougatOCR. 
+ """ + assert len(paper.name) > 0, f"Paper name must be given, given {paper.name}" + reference = paper.name + + file_handler = FileHandler(preprocessing_path, storage_type, bucket_name) + + if preprocessing_dataset is not None: + # Load the segments from the manually annotated preprocessing dataset + text_segments, _, _ = create_or_load_nougat_segments(paper, file_handler=file_handler, **nougat_kwargs) + table_segments = get_paper_tables(paper, preprocessing_dataset, response_status=response_status) + else: + # Load the segments from `nougat` preprocessed data + texts_path = join(preprocessing_path, reference, "texts.json") + tables_path = join(preprocessing_path, reference, "tables.json") + text_segments = ( + Segments.parse_raw(file_handler.read_text(texts_path)) if file_handler.exists(texts_path) else Segments() + ) + table_segments = ( + Segments.parse_raw(file_handler.read_text(tables_path)) if file_handler.exists(tables_path) else Segments() + ) + + extra_metadata = {"reference": reference} + text_nodes = create_text_nodes( + text_segments, extra_metadata=extra_metadata, exclude_llm_metadata_keys=exclude_llm_metadata_keys + ) + + table_nodes = create_table_nodes( + table_segments, extra_metadata=extra_metadata, exclude_llm_metadata_keys=exclude_llm_metadata_keys + ) + + return text_nodes, table_nodes + + +def create_text_nodes( + text_segments: Segments, extra_metadata: Optional[Dict[str, Any]], exclude_llm_metadata_keys: Iterable +) -> List[Document]: + text_documents = [] + for i, segment in enumerate(text_segments.items): + if i == 0: + continue + elif i == 1: + title = text_segments[0].text + segment.text = title + segment.text + + if "references" in segment.header.lower(): + continue + elif ( + len(segment.text) < 1000 + and "conflicts of interest" in segment.header.lower() + or "acknowledgements" in segment.header.lower() + or re.search(r"author.*contributions", segment.header.lower()) + ): + continue + elif not segment.text: + continue + + metadata = 
segment.dict(include=INCLUDE_METADATA_KEYS) + if extra_metadata: + metadata.update(extra_metadata) + + doc = Document( + id_=segment.id, + text=segment.text.strip(), + type=segment.type, + metadata=metadata, + relationships=segment.relationships, + excluded_embed_metadata_keys=exclude_llm_metadata_keys, + excluded_llm_metadata_keys=exclude_llm_metadata_keys, + ) + text_documents.append(doc) + + return text_documents + + +def create_table_nodes( + table_segments: Segments, extra_metadata: Optional[Dict[str, Any]], exclude_llm_metadata_keys: Iterable +) -> List[Document]: + table_documents = [] + for segment in table_segments.items: + if not segment.html: + continue + + try: + df = html_to_df(segment.html, convert_spanning_rows=True) + except Exception as e: + print(f"Failed to convert HTML to DataFrame: {e}") + continue + + assert df.columns.nlevels == 1, f"MultiIndex columns are not supported, given {df.columns}" + metadata = segment.dict(include=INCLUDE_METADATA_KEYS) + if extra_metadata: + metadata.update(extra_metadata) + metadata["columns"] = df.columns.tolist() + + doc = Document( + id_=segment.id, + text=df.to_json(orient="index"), + type=segment.type, + metadata=metadata, + relationships=segment.relationships, + excluded_embed_metadata_keys=exclude_llm_metadata_keys, + excluded_llm_metadata_keys=exclude_llm_metadata_keys, + ) + table_documents.append(doc) + + return table_documents diff --git a/argilla-v1/src/extralit_v1/extraction/extraction.py b/argilla-v1/src/extralit_v1/extraction/extraction.py new file mode 100644 index 000000000..d6a935283 --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/extraction.py @@ -0,0 +1,240 @@ +import logging +import os +import warnings +from os.path import join, exists +from typing import Tuple, Dict, Optional, List, Union + +import pandas as pd +import pandera as pa +from langfuse.model import TextPromptClient +from llama_index.core import VectorStoreIndex, PromptTemplate, Response +from 
llama_index.core.vector_stores import ( + MetadataFilter, + MetadataFilters, + FilterOperator, + FilterCondition, +) +from pydantic.v1 import BaseModel + +from extralit_v1.extraction.models.paper import PaperExtraction +from extralit_v1.extraction.models.response import ResponseResult, ResponseResults +from extralit_v1.extraction.models.schema import SchemaStructure +from extralit_v1.extraction.prompts import ( + create_extraction_prompt, + create_completion_prompt, + DEFAULT_EXTRACTION_PROMPT_TMPL, +) +from extralit_v1.extraction.schema import get_extraction_schema_model +from extralit_v1.extraction.utils import convert_response_to_dataframe, generate_reference_columns +from extralit_v1.extraction.vector_index import load_index +from extralit_v1.schema.references.assign import assign_unique_index, get_prefix + +_LOGGER = logging.getLogger(__name__) + + +def query_rag_index( + prompt: str, + index: VectorStoreIndex, + output_cls=BaseModel, + similarity_top_k=20, + filters: Optional[MetadataFilters] = None, + response_mode="compact", + text_qa_template=DEFAULT_EXTRACTION_PROMPT_TMPL, + **kwargs, +) -> Response: + warnings.filterwarnings("ignore", module="pydantic") + + query_engine = index.as_query_engine( + output_cls=output_cls, + response_mode=response_mode, + similarity_top_k=similarity_top_k, + filters=filters, + text_qa_template=text_qa_template, + **kwargs, + ) + + obs_response = query_engine.query(prompt) + + return obs_response + + +def extract_schema( + schema: pa.DataFrameSchema, + extractions: PaperExtraction, + index: VectorStoreIndex, + include_fields: Optional[List[str]] = None, + headers: Optional[List[str]] = None, + types: Optional[List[str]] = None, + similarity_top_k=20, + system_prompt: Optional[Union[PromptTemplate, TextPromptClient]] = DEFAULT_EXTRACTION_PROMPT_TMPL, + user_prompt: Optional[str] = None, + verbose=False, + **kwargs, +) -> Tuple[pd.DataFrame, ResponseResult]: + """ + Extract a complete table based on schema using the RAG on a 
paper. + Args: + schema (pa.DataFrameSchema): The schema to extract. + extractions (PaperExtraction): The extractions from the paper. + index (VectorStoreIndex): The index to use for the extraction. + similarity_top_k (int): The number of similar documents to retrieve. Defaults to 20. + include_fields (Optional[List[str]]): A list of column names to include in the Pydantic model. Defaults to None. + headers (Optional[List[str]]): The headers to filter the documents by. Defaults to None. + system_prompt (Optional[Union[PromptTemplate, TextPromptClient]]): The text QA template to use. Defaults to the default extraction prompt template. + verbose (bool): Whether to log the metadata filters used for retrieval. Defaults to False. + **kwargs (Dict): Additional keyword arguments to pass to the `query_rag_index` and `as_query_engine` function. + text_qa_template (PromptTemplate): Already set from `system_prompt`; do not pass it again through kwargs. + vector_store_query_mode (str): The vector store query mode. Defaults to "hybrid". + + Returns: + Tuple[pd.DataFrame, ResponseResult]: The extracted DataFrame and the ResponseResult. 
+ """ + + if schema.name in extractions.extractions: + prompt = create_completion_prompt(schema, extractions, include_fields=include_fields, extra_prompt=user_prompt) + else: + prompt = create_extraction_prompt( + schema, + extractions, + ) + + output_cls = get_extraction_schema_model( + schema, + include_fields=include_fields, + exclude_fields=["reference"], + top_class=schema.name + "s", + lower_class=schema.name, + validate_assignment=False, + ) + + filters = MetadataFilters( + filters=[MetadataFilter(key="reference", value=extractions.reference, operator=FilterOperator.EQ)], + condition=FilterCondition.AND, + ) + if headers: + filters.filters.append(MetadataFilter(key="header", value=headers, operator=FilterOperator.IN)) + if types: + filters.filters.append(MetadataFilter(key="type", value=types, operator=FilterOperator.IN)) + + if verbose: + _LOGGER.info(f"Filters {filters.__repr__()}") + + if isinstance(system_prompt, TextPromptClient): + system_prompt = PromptTemplate(system_prompt.get_langchain_prompt()) + elif isinstance(system_prompt, PromptTemplate): + pass + else: + _LOGGER.warning( + f"Invalid system_prompt type: {type(system_prompt)}, reverting to " f"DATA_EXTRACTION_SYSTEM_PROMPT_TMPL." 
+ ) + system_prompt = DEFAULT_EXTRACTION_PROMPT_TMPL + + response = query_rag_index( + prompt, + index=index, + output_cls=output_cls, + similarity_top_k=similarity_top_k, + filters=filters, + text_qa_template=system_prompt, + response_mode="compact", + **kwargs, + ) + + df = convert_response_to_dataframe(response) + df = generate_reference_columns(df, schema) + try: + response = ResponseResult(**response.__dict__) + except Exception as e: + _LOGGER.error(f"Failed to create ResponseResult: {e}") + response = ResponseResult() + + return df, response + + +def extract_paper( + paper: pd.Series, + schema_structure: SchemaStructure, + index: VectorStoreIndex = None, + prompt: str = "default", + llm_models: Union[List[str], str] = ["gpt-4o", "gpt-4-turbo"], + embed_model: str = "text-embedding-ada-002", + index_kwargs: Dict = None, + interim_path="data/interim/", + load_only=False, + verbose: int = 0, +) -> Tuple[PaperExtraction, ResponseResults]: + reference = paper.name + if isinstance(llm_models, str): + llm_models = [llm_models] + + ### Load interim results ### + interim_save_dir = join(interim_path, llm_models[0], reference) + if load_only and exists(interim_save_dir): + if not exists(interim_save_dir): + raise FileNotFoundError(f"Interim save directory does not exist: {interim_save_dir}") + with open(join(interim_save_dir, "responses.json"), "r") as file: + responses = ResponseResults.parse_raw(file.read()) + + extractions = PaperExtraction( + extractions={k: v.response.to_df() for k, v in responses.items.items()}, + schemas=schema_structure, + reference=reference, + ) + return extractions, responses + + ### Create or load the index ### + if index is None: + index = load_index(paper, llm_model=llm_models[0], embed_model=embed_model, **(index_kwargs or {})) + assert ( + index.service_context.llm.model == llm_models[0] + ), f"LLM model mismatch: {index.service_context.llm.model} != {llm_models[0]}" + + extractions = PaperExtraction(extractions={}, 
schemas=schema_structure, reference=reference) + responses = ResponseResults(items={}, docs_metadata={id: doc.metadata for id, doc in index.docstore.docs.items()}) + + ### Extract entities ### + for schema_name in extractions.schemas.ordering: + schema = extractions.schemas[schema_name] + + df = extract_schema_with_fallback( + schema=schema, extractions=extractions, index=index, responses=responses, models=llm_models, verbose=verbose + ) + + if schema.index and schema.index.name: + df = assign_unique_index(df, schema, index_name=schema.index.name, prefix=get_prefix(schema), n_digits=2) + df = df.drop_duplicates() + + extractions.extractions[schema_name] = df + + ### Save interim results ### + try: + os.makedirs(interim_save_dir, exist_ok=True) + with open(join(interim_save_dir, "responses.json"), "w") as file: + file.write(responses.json()) + except Exception as e: + _LOGGER.error(f"Interim save responses: {e}") + + return extractions, responses + + +def extract_schema_with_fallback( + schema: pa.DataFrameSchema, + extractions: PaperExtraction, + index: VectorStoreIndex, + responses: ResponseResults, + models: List[str], + verbose: Optional[int] = 0, + **kwargs, +) -> pd.DataFrame: + for model in models: + try: + index.service_context.llm.model = model + df, responses[schema.name] = extract_schema(schema=schema, extractions=extractions, index=index, **kwargs) + return df + except Exception as e: + _LOGGER.log(logging.WARNING, f"Error {schema.name} ({model}): {e}") + + if verbose >= 2: + raise e + + return pd.DataFrame() diff --git a/argilla-v1/src/extralit_v1/extraction/models/__init__.py b/argilla-v1/src/extralit_v1/extraction/models/__init__.py new file mode 100644 index 000000000..e2b00278e --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/models/__init__.py @@ -0,0 +1,9 @@ +from extralit_v1.schema.checks import register_check_methods + +register_check_methods() + +from .schema import SchemaStructure, DEFAULT_SCHEMA_S3_PATH +from .response import 
ResponseResults +from .paper import PaperExtraction + +__all__ = ["SchemaStructure", "ResponseResults", "PaperExtraction", "DEFAULT_SCHEMA_S3_PATH"] diff --git a/argilla-v1/src/extralit_v1/extraction/models/paper.py b/argilla-v1/src/extralit_v1/extraction/models/paper.py new file mode 100644 index 000000000..34923350f --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/models/paper.py @@ -0,0 +1,134 @@ +import itertools +import logging +from datetime import datetime +from typing import Dict, Iterator, Tuple, Optional, Union +from uuid import UUID + +import pandas as pd +import pandera as pa +from pandera.api.base.model import MetaModel +from pydantic.v1 import BaseModel, Field + +from extralit_v1.extraction.models.schema import SchemaStructure + +_LOGGER = logging.getLogger(__name__) + + +class PaperExtraction(BaseModel): + reference: str + extractions: Dict[str, pd.DataFrame] = Field(default_factory=dict) + schemas: SchemaStructure = Field(..., description="The schema structure of the extraction.") + durations: Dict[str, Optional[float]] = Field(default_factory=dict) + updated_at: Dict[str, Optional[datetime]] = Field(default_factory=dict) + inserted_at: Dict[str, Optional[datetime]] = Field(default_factory=dict) + user_id: Dict[str, Optional[UUID]] = Field(default_factory=dict) + + class Config: + arbitrary_types_allowed = True + + def get_joined_data(self, schema_name: str, drop_joined_index=True) -> pd.DataFrame: + """ + Join the extraction DataFrame with the dependent DataFrames based on the schema index. + Args: + schema_name: The schema name to join. + drop_joined_index: Drop the joined index column. + index_name: The index name to join on. 
+ + Returns: + + """ + schema = self.schemas[schema_name] + df = self[schema_name].copy() + + # For each '_ref' key, find the matching DataFrame with the same DataFrameModel prefix + for ref_column in self.schemas.index_names(schema_name): + dep_schema_name = self.schemas.get_ref_schema(ref_column).name + if ref_column not in df.index.names and ref_column not in df.columns: + # Skip if the DataFrame is already joined + _LOGGER.info(f"Skipping join on {ref_column} as it is already joined. \n{df.index.names}\n{df.columns}") + continue + + dependent_df = next( + ( + value.copy() + for key, value in self.extractions.items() + if str(key).lower() == dep_schema_name.lower() and value.size > 0 + ), + None, + ) + if dependent_df is None: + continue + + if self.user_id.get(dep_schema_name) and self.user_id.get(dep_schema_name) == self.user_id.get(schema_name): + print(f"Skipping join on {dep_schema_name} as it is the same user.") + _LOGGER.info(f"Skipping join on {dep_schema_name} as it is the same user.") + + try: + dependent_df = dependent_df.rename_axis(index={"reference": ref_column}) + df = df.join(dependent_df, how="left", rsuffix="_joined") + df = overwrite_joined_columns(df, rsuffix="_joined", prepend=True) + if drop_joined_index and ref_column in df.index.names: + df = df.reset_index(level=ref_column, drop=True) + except NotImplementedError as e: + _LOGGER.info(f"{dep_schema_name}-{schema.name} extraction table is already joined.") + except Exception as e: + _LOGGER.error(f"Failed to join `{dep_schema_name}` to {schema.name}: {e}") + raise e + + return df + + @property + def size(self): + return sum(df.size for schema_name, df in self.extractions.items()) + + def __getitem__(self, item: str) -> pd.DataFrame: + if isinstance(item, pa.DataFrameSchema): + return self.extractions[item.name] + elif isinstance(item, MetaModel): + return self.extractions[str(item)] + return self.extractions[item] + + def __contains__(self, item: Union[pa.DataFrameModel, str]) -> bool: + 
if isinstance(item, pa.DataFrameSchema): + return item.name in self.extractions + elif isinstance(item, MetaModel): + return str(item) in self.extractions + return item in self.extractions + + def __getattr__(self, item: str) -> pd.DataFrame: + if self.__contains__(item): + return self.__getitem__(item) + else: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + + def __dir__(self) -> Iterator[str]: + extraction_keys = [str(key) for key in self.extractions.keys()] + return itertools.chain(super().__dir__(), extraction_keys) + + def __setitem__(self, key: str, value: pd.DataFrame) -> None: + assert isinstance(key, str), f"Expected str, got {type(key)}" + self.extractions[key] = value + + def items(self) -> Iterator[Tuple[str, pd.DataFrame]]: + return self.extractions.items() + + def __repr_args__(self): + args = [(k, v.dropna(axis=1, how="all").shape) for k, v in self.extractions.items() if v.size] + return args + + +def overwrite_joined_columns(df: pd.DataFrame, rsuffix="_joined", prepend=True) -> pd.DataFrame: + # Overwrite the original column with the '_joined' column + suffix_columns = [col for col in df.columns if col.endswith(rsuffix)] + joined_columns = [col.rsplit(rsuffix, 1)[0] for col in suffix_columns] + + for joined_col in suffix_columns: + original_col = joined_col.rsplit(rsuffix, 1)[0] + df[original_col] = df[joined_col] + df = df.drop(columns=joined_col) + + if prepend: + column_reorder = [*joined_columns, *df.columns.difference(joined_columns)] + df = df.reindex(columns=column_reorder) + + return df diff --git a/argilla-v1/src/extralit_v1/extraction/models/response.py b/argilla-v1/src/extralit_v1/extraction/models/response.py new file mode 100644 index 000000000..f00ea1418 --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/models/response.py @@ -0,0 +1,127 @@ +from typing import Union, List, Dict, Any, Optional +from pydantic.v1 import BaseModel, validator, Field +from typing_extensions import TypedDict + 
+import pandas as pd +import tiktoken +from llama_index.core import VectorStoreIndex +from llama_index.core.schema import TextNode + +from extralit_v1.extraction.staging import to_df + + +class BaseModelForLlamaIndexResponse(BaseModel): + items: List[Dict[str, Any]] + + def to_df(self, *args, **kwargs) -> pd.DataFrame: + return to_df(self, *args, **kwargs) + + +class SourceNode(TypedDict): + node: Union[TextNode, Dict[str, Union[str, Dict[str, str]]]] + score: Optional[float] + + class Config: + arbitrary_types_allowed = True + + +class ResponseResult(BaseModel): + response: Optional[BaseModelForLlamaIndexResponse] + source_nodes: Optional[List[SourceNode]] + metadata: Optional[Dict[str, Dict[str, Any]]] + + @validator("source_nodes", pre=True) + def parse_source_nodes(cls, v): + return [dict(node) for node in v] + + def get_nodes_info( + self, + count_tokens=False, + tokenizer_model="text-embedding-3-small", + max_char_len=25, + header_doc_id_map: Optional[Dict[str, str]] = None, + ) -> pd.DataFrame: + if count_tokens: + tiktoken_encoder = tiktoken.get_encoding(tokenizer_model) + + nodes_dict = {} + for i, source_node in enumerate(self.source_nodes, start=1): + node, score = source_node["node"], source_node["score"] + text, metadata = node["text"], node["metadata"] + tokens = len(tiktoken_encoder.encode(text)) if count_tokens and text else None + + if header_doc_id_map: + header = metadata.get("header") + if header in header_doc_id_map: + metadata["doc_id"] = header_doc_id_map[header] + + nodes_dict[i] = { + "doc_id": node["id_"], + "relevance": score, + **metadata, + "text": text[:max_char_len].replace("\n", "") + ("..." 
if len(text) > max_char_len else ""), + } + if count_tokens: + nodes_dict[i]["n_tokens"] = tokens + + context_df = pd.DataFrame.from_dict(nodes_dict, orient="index") + + return context_df + + class Config: + arbitrary_types_allowed = True + + +class ResponseResults(BaseModel): + items: Dict[str, ResponseResult] = Field(default_factory=dict) + docs_metadata: Dict[str, Dict[str, Any]] = Field( + default_factory=dict, description="Metadata for all nodes in the RAG index" + ) + + def init_docs_from_index(self, index: VectorStoreIndex, reference: str): + if type(index.vector_store).__name__ == "WeaviateVectorStore": + from extralit_v1.extraction.query import get_nodes_metadata + + weaviate_client = index.vector_store.client + + results = get_nodes_metadata( + weaviate_client, + index_name=index.vector_store.index_name, + properties=["reference", "header", "doc_id", "page_number"], + filters={"reference": reference}, + ) + docs_metadata = {result["doc_id"]: result for result in results} + else: + docs_metadata = {id: doc.metadata for id, doc in index.docstore.docs.items()} + + self.docs_metadata = docs_metadata + + def get_nodes_info(self, schema_name=None, **kwargs) -> pd.DataFrame: + if schema_name: + return self.items[schema_name].get_nodes_info(**kwargs) + elif not schema_name and not self.docs_metadata: + raise ValueError("No metadata available to extract nodes from, run `init_docs_from_index` first.") + + df = pd.DataFrame.from_dict(self.docs_metadata, orient="index") + df.index.name = "doc_id" + return df.reset_index(drop=df.index.name in df.columns) + + def get_ranked_nodes(self, schema_name: str, include_all_nodes=True) -> pd.DataFrame: + header_doc_id_map = {doc["header"]: doc["doc_id"] for doc in self.docs_metadata.values()} + selected_nodes = self.items[schema_name].get_nodes_info(count_tokens=False, header_doc_id_map=header_doc_id_map) + if include_all_nodes: + all_nodes = self.get_nodes_info() + ranked_nodes = pd.concat([selected_nodes, 
all_nodes]).drop_duplicates(subset=["header"]) + else: + ranked_nodes = selected_nodes + + return ranked_nodes + + def __getitem__(self, key): + return self.items[key] + + def __setitem__(self, key, value): + self.items[key] = value + + class Config: + arbitrary_types_allowed = True diff --git a/argilla-v1/src/extralit_v1/extraction/models/schema.py b/argilla-v1/src/extralit_v1/extraction/models/schema.py new file mode 100644 index 000000000..02cd2b40b --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/models/schema.py @@ -0,0 +1,348 @@ +import logging +import os +from collections import deque +from glob import glob +from io import BytesIO +from pathlib import Path +from typing import List, Optional, Union, Dict + +import pandera as pa +import argilla_v1 as rg +from minio import Minio +from pandera.api.base.model import MetaModel +from pandera.io import from_json, from_yaml +from pydantic.v1 import BaseModel, Field, validator + +DEFAULT_SCHEMA_S3_PATH = "schemas/" + +_LOGGER = logging.getLogger(__name__) + + +def topological_sort( + schema_name: str, visited: Dict[str, int], stack: deque, dependencies: Dict[str, List[str]] +) -> None: + visited[schema_name] = 1 # Gray + + for i in dependencies.get(schema_name, []): + if visited[i] == 1: # If the node is gray, it means we have a cycle + raise ValueError(f"Circular dependency detected: {schema_name} depends on {i} and vice versa") + if visited[i] == 0: # If the node is white, visit it + topological_sort(i, visited, stack, dependencies) + + visited[schema_name] = 2 # Black + stack.appendleft(schema_name) + + +class SchemaStructure(BaseModel): + """ + A class representing the structure of a schema. 
+ + Usage: + ```python + from pandera import DataFrameSchema + from extralit_v1.extraction.models.schema import SchemaStructure + + schema_structure = SchemaStructure( + schemas=[ + DataFrameSchema( + columns={ + "name": pa.Column(pa.String), + "age": pa.Column(pa.Int) + } + ) + ] + ) + ``` + """ + + schemas: List[pa.DataFrameSchema] = Field(default_factory=list, description="A list of all the extraction schemas.") + singleton_schema: Optional[pa.DataFrameSchema] = Field( + None, repr=True, description="A singleton schema that exists in `schemas` list." + ) + + def __init__(self, **data): + super().__init__(**data) + + # Ensure singleton_schema is in schemas list + if self.singleton_schema and self.singleton_schema not in self.schemas: + self.schemas.append(self.singleton_schema) + + for schema in self.schemas: + is_singleton_schema = any( + check.name == "singleton" and check.statistics.get("enabled", True) for check in schema.checks + ) + + if is_singleton_schema and not self.singleton_schema: + self.singleton_schema = schema + elif is_singleton_schema and self.singleton_schema and schema != self.singleton_schema: + raise ValueError("Only one singleton schema is allowed in the schema structure") + + @validator("schemas", pre=True, each_item=True) + def parse_schema(cls, v: Union[pa.DataFrameModel, pa.DataFrameSchema]): + return v.to_schema() if hasattr(v, "to_schema") else v + + @validator("singleton_schema", pre=True) + def parse_singleton_schema(cls, v: Union[pa.DataFrameModel, pa.DataFrameSchema]): + schema: pa.DataFrameSchema = v.to_schema() if hasattr(v, "to_schema") else v + assert all( + key.islower() for key in schema.columns.keys() + ), f"All keys in {schema.name} schema must be lowercased" + return schema + + @classmethod + def from_dir(cls, dir_path: Path, exclude: Optional[List[str]] = None): + """ + Load a SchemaStructure from a directory containing pandera DataFrameSchema .json files. 
+ Args: + dir_path: A directory path containing pandera DataFrameSchema .json files. + exclude: A list of schema names to exclude from the schema structure. + + Returns: + SchemaStructure + """ + schemas = {} + if os.path.isdir(dir_path): + schema_paths = sorted(glob(os.path.join(dir_path, "*.json")), key=lambda x: not x.endswith(".json")) + else: + schema_paths = sorted(glob(dir_path), key=lambda x: not x.endswith(".json")) + + for filepath in schema_paths: + try: + if filepath.endswith(".json"): + schema = from_json(filepath) + elif filepath.endswith(".yaml") or filepath.endswith(".yml"): + schema = from_yaml(filepath) + else: + continue + + if schema.name in schemas or (exclude and schema.name in exclude): + continue + + schemas[schema.name] = schema + except Exception as e: + _LOGGER.warning(f"Ignoring failed schema loading from '{filepath}': \n{e}") + + return cls(schemas=list(schemas.values())) + + @classmethod + def from_workspace(cls, workspace: "rg.Workspace", prefix: str = DEFAULT_SCHEMA_S3_PATH, exclude: List[str] = []): + return workspace.get_schemas(prefix=prefix, exclude=exclude) + + @classmethod + def from_s3( + cls, + workspace_name: str, + minio_client: Minio, + prefix: str = DEFAULT_SCHEMA_S3_PATH, + exclude: List[str] = [], + verbose: bool = True, + ): + """ + Load a SchemaStructure from a Minio bucket containing pandera DataFrameSchema .json files. + + Args: + workspace: The workspace name. + minio_client: The Minio client. + prefix: The prefix to search for schemas. + exclude: A list of schema names to exclude from the schema structure. + verbose: Whether to log verbose output. 
+ + Returns: + SchemaStructure + """ + schemas = {} + objects = minio_client.list_objects(workspace_name, prefix=prefix, include_version=False) + + # Sort the objects by file extension + objects = sorted( + objects, key=lambda obj: (os.path.splitext(obj.object_name)[1] != "", os.path.splitext(obj.object_name)[1]) + ) + + for obj in objects: + filepath = obj.object_name + file_extension = os.path.splitext(filepath)[1] + + try: + data = minio_client.get_object(workspace_name, filepath) + file_data = BytesIO(data.read()) + + if not file_extension or file_extension == ".json": + schema = from_json(file_data) + elif file_extension in [".yaml", ".yml"]: + schema = from_yaml(file_data) + else: + continue + + if schema.name in schemas or schema.name in exclude: + continue + + _LOGGER.info(f"Loaded {schema.name} from {filepath}", exc_info=1) + schemas[schema.name] = schema + except Exception as e: + _LOGGER.warning(f"Ignoring failed schema loading from '{filepath}': \n{e}") + + return cls(schemas=list(schemas.values())) + + def to_s3(self, workspace_name: str, minio_client: Minio, prefix: str = "schemas/", delete_excluded: bool = False): + """ + This method is used to upload the schemas to an S3 bucket and optionally delete the excluded schemas. + + Args: + workspace (str): The workspace name. + minio_client (Minio): The Minio client. + prefix (str, optional): The prefix to use for the schemas in the S3 bucket. Default is 'schemas/'. + delete_excluded (bool, optional): A flag to determine whether to delete the excluded schemas or not. Default is True. 
+ + Returns: + None + """ + + for schema in self.schemas: + # Serialize the schema to a JSON string + schema_json = schema.to_json() + + # Create a BytesIO object from the JSON string + schema_bytes = BytesIO(schema_json.encode()) + + # Define the object name + object_name = os.path.join(prefix, schema.name) + + # Upload the BytesIO object to the S3 bucket + minio_client.put_object( + bucket_name=workspace_name, + object_name=object_name, + data=schema_bytes, + length=schema_bytes.getbuffer().nbytes, + content_type="application/json", + ) + + if delete_excluded: + objects = minio_client.list_objects(workspace_name, prefix=prefix, include_version=False) + bucket_schema_paths = [os.path.splitext(obj.object_name)[0] for obj in objects] + self_schema_paths = [os.path.join(prefix, schema.name) for schema in self.schemas] + schemas_to_delete = set(bucket_schema_paths) - set(self_schema_paths) + print("Deleting schemas:", schemas_to_delete) + for schema_path in schemas_to_delete: + minio_client.remove_object(workspace_name, schema_path) + + def get_joined_schema(self, schema_name: str): + combined_columns = {} + combined_checks = [] + + # Iterate over the provided schema and its dependent schemas + dependent_schemas: List[pa.DataFrameSchema] = [ + self.__getitem__(sn) for sn in self.upstream_dependencies.get(schema_name) + ] + + for schema in [self.__getitem__(schema_name)] + dependent_schemas: + for column_name, column_schema in schema.columns.items(): + if column_name not in combined_columns: + combined_columns[column_name] = column_schema + + combined_checks.extend(schema.checks) + + joined_schema = pa.DataFrameSchema(columns=combined_columns, checks=combined_checks, name=schema_name) + return joined_schema + + @property + def downstream_dependencies(self) -> Dict[str, List[str]]: + dependents = {} + for schema in self.schemas: + dependents[schema.name] = [] + schema_index_names = self.index_names(schema) + + for dep in self.schemas: + if not dep.index or schema == 
dep: + continue + dep_index_names = self.index_names(dep) + if f"{schema.name}_ref".lower() in dep_index_names: + dependents[schema.name].append(dep.name) + + if ( + schema.index + and f"{schema.name}_ID" in dep_index_names + and f"{schema.name}_ID" in schema_index_names + ): + dependents[schema.name].append(dep.name) + return dependents + + @property + def upstream_dependencies(self) -> Dict[str, List[str]]: + dependencies = {} + for schema in self.schemas: + dependencies[schema.name] = [] + schema_index_names = self.index_names(schema) + + for dep in self.schemas: + if not schema.index or schema == dep: + continue + dep_index_names = self.index_names(dep) + if f"{dep.name}_ref".lower() in schema_index_names: + dependencies[schema.name].append(dep.name) + + if dep.index and f"{dep.name}_ID" in schema_index_names and f"{dep.name}_ID" in dep_index_names: + dependencies[schema.name].append(dep.name) + return dependencies + + def index_names(self, schema: Union[str, pa.DataFrameSchema]) -> List[str]: + schema = self.__getitem__(schema) if isinstance(schema, str) else schema + if not schema.index: + return [] + index_names = list(schema.index.names or [schema.index.name]) + index_names = [name for name in index_names if name] + return index_names + + def get_ref_schema(self, ref_column: str) -> pa.DataFrameSchema: + if not ref_column.endswith("_ref") and not ref_column.endswith("_ID"): + raise ValueError(f"Foreign key '{ref_column}' must contain '_ref' or '_ID' suffix") + schema_name = ref_column.rsplit("_ref", 1)[0].rsplit("_ID", 1)[0] + return self.__getitem__(schema_name) + + def columns(self, schema: str) -> List[str]: + columns = list(self.__getitem__(schema).columns) + return columns + + @property + def ordering(self) -> List[str]: + visited = {schema.name: 0 for schema in self.schemas} + stack = deque() + + # Ensure singleton_schema is ordered first + if self.singleton_schema: + stack.append(self.singleton_schema.name) + visited[self.singleton_schema.name] = 2 # 
Mark as visited (black) + + for schema in self.schemas: + if visited[schema.name] == 0: + # If the node is white, visit it + topological_sort(schema.name, visited, stack, self.downstream_dependencies) + + # Ensure singleton_schema is at the beginning of the list + ordered_list = list(stack) + if self.singleton_schema and ordered_list[0] != self.singleton_schema.name: + ordered_list.remove(self.singleton_schema.name) + ordered_list.insert(0, self.singleton_schema.name) + + return ordered_list + + def __iter__(self): + return iter(self.ordering) + + def __getitem__(self, item: str): + if isinstance(item, pa.DataFrameSchema): + item = item.name + elif isinstance(item, MetaModel): + item = str(item) + + for schema in self.schemas: + if schema.name.lower() == item.lower(): + return schema + raise KeyError(f"No schema found for '{item}'") + + class Config: + arbitrary_types_allowed = True + + def __repr__(self): + schema_names = [schema.name for schema in self.schemas] + singleton_schema_name = self.singleton_schema.name if self.singleton_schema else None + return f"SchemaStructure(schemas={schema_names}, singleton_schema={singleton_schema_name})" diff --git a/argilla-v1/src/extralit_v1/extraction/prompts.py b/argilla-v1/src/extralit_v1/extraction/prompts.py new file mode 100644 index 000000000..eb55422ce --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/prompts.py @@ -0,0 +1,142 @@ +import json +from typing import List, Optional + +import pandera as pa +from llama_index.core import PromptTemplate, ChatPromptTemplate +from llama_index.core.base.llms.types import MessageRole, ChatMessage +from llama_index.core.prompts import chat_prompts + +from extralit_v1.extraction.models import PaperExtraction +from extralit_v1.extraction.schema import get_extraction_schema_model, drop_type_def_from_schema_json +from extralit_v1.extraction.utils import filter_unique_columns, stringify_lists + +FIGURE_TABLE_EXT_PROMPT_TMPL = PromptTemplate( + """Given the figure from a research 
paper, please extract only the variables and observations names of the figure/chart as columns header and rows index in an HTML table, but do not extract any numerical data values. +Figure information is below. +--------------------- +{header_str} +--------------------- +Answer:""" +) + + +DATA_EXTRACTION_SYSTEM_PROMPT_TMPL = PromptTemplate( + """Your ability to extract and summarize this context accurately is essential for effective analysis. Pay close attention to the context's language, structure, and any cross-references to ensure a comprehensive and precise extraction of information. Do not use prior knowledge or information from outside the context to answer the questions. Only use the information provided in the context to answer the questions.\n""" + "Context information is below.\n" + "---------------------\n" + "{context_str}\n" + "---------------------\n" + "Query: {query_str}\n" + "Answer: " +) + +# DEFAULT_EXTRACTION_PROMPT_TMPL = PromptTemplate(default_prompts.DEFAULT_TEXT_QA_PROMPT_TMPL) +DEFAULT_EXTRACTION_PROMPT_TMPL = DATA_EXTRACTION_SYSTEM_PROMPT_TMPL + + +def create_extraction_prompt(schema: pa.DataFrameSchema, extractions: PaperExtraction, filter_unique_cols=False) -> str: + prompt = ( + f"Your task is to extract data from a research paper.\n" + f"The `{schema.name}` details can be split across the provided context. Respond with details by looking at the whole context always.\n" + f"If you don't find the information in the given context or you are not sure, " + f"omit the key-value in your JSON response. " + ) + schema_structure = extractions.schemas + dependencies = schema_structure.upstream_dependencies[schema.name] + if dependencies: + prompt += ( + f"The `{schema.name}` data you're extracting is dependent on the provided " + f"`{stringify_lists(dependencies, conjunction='and')}` tables containing entities which you need to reference. " + f"There can be multiple `{schema.name}` data entries for each unique combination of these references." 
+ f"Here are the data already extracted from the paper:\n\n" + ) + + # Inject prior extraction data into the query + for dep_schema_name in dependencies: + if dep_schema_name not in extractions.extractions: + raise ValueError(f"Dependency '{dep_schema_name}' not found in extractions") + + if filter_unique_cols: + dep_extraction = filter_unique_columns(extractions[dep_schema_name]) + else: + dep_extraction = extractions[dep_schema_name] + + schema_json = get_extraction_schema_model( + schema_structure[dep_schema_name], + include_fields=dep_extraction.columns.tolist(), + singleton=True, + description_only=True, + ).schema() + schema_definition = json.dumps(drop_type_def_from_schema_json(schema_json)) + prompt += ( + f"###{dep_schema_name}###\n" + f"Schema:\n" + f"{schema_definition}\n" + f"Data:\n" + f"{dep_extraction.to_json(orient='index')}\n\n" + ) + + return prompt + + +def create_completion_prompt( + schema: pa.DataFrameSchema, + extractions: PaperExtraction, + include_fields: List[str], + filter_unique_cols=True, + extra_prompt: Optional[str] = None, +) -> str: + assert schema.name in extractions.extractions, f"Schema '{schema.name}' not found in extractions" + prompt = create_extraction_prompt(schema, extractions, filter_unique_cols) + existing_extraction = extractions[schema.name] + + note = f"Note: {extra_prompt}\n" if extra_prompt else "" + + prompt += ( + f'Please complete the following `{schema.name}` table by extracting the {include_fields} fields ' + f'for the following {len(existing_extraction)} entries. 
The rows you\'re filling in may not match the same order ' + f'as the rows in the provided context, so be sure to match the correct rows based on the existing values.\n' + f'{note}' + f'###{schema.name}###\n' + f'Data:\n' + f'{existing_extraction.reset_index().to_json(orient="index")}\n\n' + ) + + return prompt + + +CHAT_SYSTEM_PROMPT = ( + "You are a research assistant helping the user perform data extraction from a research paper.\n" + "Always answer the query using the provided context information, " + "and not prior knowledge.\n" + "Some rules to follow:\n" + "1. Give a list of `header` from the context at the end of your answer to help the user identify the section in the paper.\n" + "2. If the information is not avilable or you are unsure, state that the information is unavailable.\n" +) + +TEXT_QA_SYSTEM_PROMPT = ChatMessage( + content=( + "You are an expert Q&A system that is trusted around the world.\n" + "Always answer the query using the provided context information, " + "and not prior knowledge.\n" + ), + role=MessageRole.SYSTEM, +) + +TEXT_QA_PROMPT_TMPL_MSGS = [ + TEXT_QA_SYSTEM_PROMPT, + ChatMessage( + content=( + "Context information is below.\n" + "---------------------\n" + "{context_str}\n" + "---------------------\n" + "Given the context information and not prior knowledge, " + "answer the query. 
Always include your references using `header` from the context to help the user identify the section in the paper\n" + "Query: {query_str}\n" + "Answer: " + ), + role=MessageRole.USER, + ), +] +DEFAULT_CHAT_PROMPT_TMPL = ChatPromptTemplate(message_templates=chat_prompts.TEXT_QA_PROMPT_TMPL_MSGS) diff --git a/argilla-v1/src/extralit_v1/extraction/query.py b/argilla-v1/src/extralit_v1/extraction/query.py new file mode 100644 index 000000000..3234968b6 --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/query.py @@ -0,0 +1,103 @@ +import logging +from typing import List, Dict, Any, Optional, Union + +from llama_index.core.vector_stores import ( + MetadataFilter, + MetadataFilters, + FilterOperator, + FilterCondition, +) +from llama_index.vector_stores.weaviate.base import _to_weaviate_filter +from llama_index.vector_stores.weaviate.utils import validate_client, class_schema_exists +from weaviate import WeaviateClient +from weaviate.exceptions import WeaviateQueryError + +_LOGGER = logging.getLogger(__name__) + + +def get_nodes_metadata( + weaviate_client: WeaviateClient, + filters: Union[Dict[str, Any], MetadataFilters], + index_name: str = "LlamaIndexDocumentSections", + properties: Union[List, Dict] = ["header", "page_number", "type", "reference", "doc_id"], + limit: Optional[int] = None, +) -> List[Dict[str, Any]]: + """ + Query document nodes and metadata from Vector DB based on specified filters. + + Args: + weaviate_client (WeaviateClient): The Weaviate client object. + filters (Union[Dict[str, Any], MetadataFilters]): The filters to apply on the metadata. + It can be either a dictionary of key-value pairs or a MetadataFilters object. + index_name (str, optional): The name of the index to query. Defaults to 'LlamaIndexDocumentSections'. + properties (Union[List, Dict], optional): The properties to include in the query result. + It can be a list of property names or a dictionary of property names and their values. 
+ Defaults to ['header', 'page_number', 'type', 'reference', 'doc_id']. + limit (Optional[int], optional): The maximum number of results to return. Defaults to None. + Returns: + List[Dict[str, Any]]: A list of dictionaries representing the metadata of the nodes. + Raises: + None + Examples: + # Example 1: Retrieve metadata with simple filters + filters = {'type': 'chapter', 'reference': 'ch01'} + metadata = get_nodes_metadata(weaviate_client, filters) + # Example 2: Retrieve metadata with complex filters + MetadataFilter(key='type', value='chapter', operator=FilterOperator.EQ), + MetadataFilter(key='reference', value=['ch01', 'ch02'], operator=FilterOperator.IN) + ], + metadata = get_nodes_metadata(weaviate_client, filters, limit=10) + """ + + validate_client(weaviate_client) + if not class_schema_exists(weaviate_client, index_name): + return [] + + if isinstance(filters, dict): + assert set(filters.keys()).issubset( + properties + ), f"Filters {list(filters)} must be a subset of properties {list(properties)}" + filters = MetadataFilters( + filters=[ + MetadataFilter(key=k, value=v, operator=FilterOperator.IN if isinstance(v, list) else FilterOperator.EQ) + for k, v in filters.items() + ], + condition=FilterCondition.AND, + ) + + collection = weaviate_client.collections.get(index_name) + + try: + query_result = collection.query.fetch_objects( + filters=_to_weaviate_filter(filters), + return_properties=properties, + limit=limit, + ) + + entries = [o.properties for o in query_result.objects] + return entries + + except WeaviateQueryError as wqe: + _LOGGER.error("Error while querying Weaviate: %s", wqe) + return [] + + +def vectordb_contains_any( + reference: str, + *, + filters: Optional[Dict[str, str]] = None, + weaviate_client: WeaviateClient = None, + index_name: str = "LlamaIndexDocumentSections", +) -> bool: + if weaviate_client is None: + return False + + nodes = get_nodes_metadata( + weaviate_client, + index_name=index_name, + filters={"reference": reference, 
**(filters or {})}, + properties=["doc_id", "reference", "type"], + limit=1, + ) + + return len(nodes) > 0 diff --git a/argilla-v1/src/extralit_v1/extraction/schema.py b/argilla-v1/src/extralit_v1/extraction/schema.py new file mode 100644 index 000000000..7d460d8c2 --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/schema.py @@ -0,0 +1,233 @@ +import logging +from typing import List, Optional, Type, Union, Dict, Any + +import pandas as pd +import pandera as pa +from pydantic.v1 import BaseModel, Field, create_model + +from extralit_v1.extraction.utils import stringify_lists +from extralit_v1.extraction.staging import heal_json, to_df + + +class SchemaStructuredOutputResponseModel(BaseModel): + @classmethod + def parse_raw(cls, b, **kwargs): + healed_json_string = heal_json(b) + try: + output = super().parse_raw(healed_json_string, **kwargs) + except Exception as e: + logging.error(f"Error parsing {cls.__name__}: {e}\n" f'Given: "{healed_json_string}"') + return cls(items=[]) + + return output + + def to_df(self, *args, **kwargs) -> pd.DataFrame: + return to_df(self, *args, **kwargs) + + class Config: + validate_assignment = False + arbitrary_types_allowed = True + + +def clean_docstring(description: str) -> Optional[str]: + if description is None: + return None + + cleaned_description = description.strip().replace("\n", " ") + return cleaned_description + + +def pandera_dtype_to_python_type(pandera_dtype: Type[pa.typing.Series]) -> Type: + if isinstance(pandera_dtype, pa.DataType): + dtype = pandera_dtype.type + + if dtype == "int": + return Optional[Union[int, str]] + elif dtype == "float": + return Optional[Union[float, str]] + elif dtype == list: + return Optional[List[str]] + elif dtype == bool: + return Optional[bool] + elif dtype in ["O", "S", "U"]: + return Optional[str] + + return Optional[str] + + +def pandera_column_to_pydantic_field(column: pa.Column, validate_assignment=False, description_only=False) -> Field: + description = column.description 
or "" + + if description_only: + return Field(None, title=column.title, description=description) + + if column.checks: + description += "\nSpecifications:" + + validators = {} + extra = {} + + for check in column.checks: + if "greater_than_or_equal_to" == check.name: + validators["ge"] = check.statistics["min_value"] + if not validate_assignment: + description += f"\n{check.name}: {next(iter(check.statistics.values()), None)}" + elif "less_than_or_equal_to" == check.name: + validators["le"] = check.statistics["max_value"] + if not validate_assignment: + description += f"\n{check.name}: {next(iter(check.statistics.values()), None)}" + elif "less_than" == check.name: + validators["lt"] = check.statistics["max_value"] + if not validate_assignment: + description += f"\n{check.name}: {next(iter(check.statistics.values()), None)}" + elif "greater_than" == check.name: + validators["gt"] = check.statistics["min_value"] + if not validate_assignment: + description += f"\n{check.name}: {next(iter(check.statistics.values()), None)}" + + elif "str_matches" == check.name: + validators["regex"] = check.statistics["pattern"] + if not validate_assignment: + extra["str_matches"] = check.statistics["pattern"] + elif "str_length" == check.name and "min_value" in check.statistics: + validators["min_length"] = check.statistics["min_value"] + if not validate_assignment: + extra.setdefault("str_length", {})["min_value"] = check.statistics["min_value"] + elif "str_length" == check.name and "max_value" in check.statistics: + validators["max_length"] = check.statistics["max_value"] + if not validate_assignment: + extra.setdefault("str_length", {})["max_value"] = check.statistics["max_value"] + + elif "str_startswith" == check.name or "str_endswith" == check.name: + description += f'\n{check.name}: "{check.statistics["string"]}"' + + elif "suggestion" == check.name: + # description += f"\nSuggestion: {stringify_to_instructions(check.statistics['values'])}" + extra["suggestion"] = 
check.statistics["values"] + elif "isin" == check.name: + description += f"\nAllowed values: {stringify_lists(check.statistics['allowed_values'])}" + extra["allowed_values"] = check.statistics["allowed_values"] + elif "notin" == check.name: + description += f"\nForbidden values: {stringify_lists(check.statistics['forbidden_values'])}" + extra["forbidden_values"] = check.statistics["forbidden_values"] + + elif check.name == "multiselect": + description += f'\nmultivalues: "{check.statistics["delimiter"]}" delimited' + + else: + description += f"\n{check.name}: {check.statistics}" + + if description.endswith("\nSpecifications:"): + description = description.replace("\nSpecifications:", "") + + if not validate_assignment: + return Field( + None, title=column.title, description=description, **(dict(json_schema_extra=extra) if extra else {}) + ) + + return Field( + None, + title=column.title, + description=description, + **validators, + **(dict(json_schema_extra=extra) if extra else {}), + ) + + +def get_extraction_schema_model( + schema: pa.DataFrameSchema, + include_fields: List[str] = None, + exclude_fields: List[str] = None, + top_class: Optional[str] = None, + lower_class: Optional[str] = None, + singleton=False, + validate_assignment=False, + description_only=False, +) -> Type[SchemaStructuredOutputResponseModel]: + """ + Converts a Pandera DataFrameSchema to a Pydantic model. This model encodes checks and dtypes which will be used as a + prompt to guide an LLM's JSON output. The function dynamically creates Pydantic models with fields based on the + schema columns and index. If the schema is not a singleton, it creates a lower-level model for each row and a + top-level model that contains a list of the lower-level models. If the schema is a singleton, it creates a single + top-level model. + + Args: + schema (pa.DataFrameSchema): The Pandera DataFrameSchema to convert. 
+ include_fields (List[str], optional): A list of column names to include in the Pydantic model. Defaults to None. + exclude_fields (List[str], optional): A list of column names to exclude in the Pydantic model. Defaults to None. + top_class (str, optional): The name of the top-level Pydantic model. Defaults to a plural form of the schema name. + lower_class (str, optional): The name of the lower-level Pydantic model. Defaults to the schema name. + singleton (bool, optional): Whether the schema represents a singleton. Defaults to False. + When True, the top-level model will not contain a list of lower-level model and only contain a single + value for each field. + validate_assignment: Whether to enforce validators on the LLM extractions, and potentially raising Exceptions. + Defaults to False. + description_only: Whether to include only the description in the Pydantic model. + Defaults to False. + + Returns: + Type[SchemaStructuredOutputResponseModel]: The Pydantic model that represents the schema definition and constraints. 
+ """ + assert isinstance(schema, pa.DataFrameSchema), f"Expected DataFrameSchema, got {type(schema)}" + if top_class is None: + top_class = schema.name + ("s" if not singleton else "") + if lower_class is None: + lower_class = schema.name + + # Dynamically create fields for the lower-level model based on schema columns + columns = { + field_name: ( + pandera_dtype_to_python_type(column.dtype), + pandera_column_to_pydantic_field( + column, validate_assignment=validate_assignment, description_only=description_only + ), + ) + for field_name, column in schema.columns.items() + if not include_fields or field_name in include_fields + } + for field_name in include_fields or []: + if field_name not in columns: + columns[field_name] = (Optional[str], Field(None, title=field_name)) + + # Add fields from schema.index + index_fields = schema.index.indexes if hasattr(schema.index, "indexes") else [schema.index] + indexes = { + index.name: ( + pandera_dtype_to_python_type(index.dtype), + pandera_column_to_pydantic_field(index, validate_assignment=validate_assignment), + ) + for index in index_fields + if (index and index.name) + } + if exclude_fields: + columns = {k: v for k, v in columns.items() if k not in exclude_fields} + indexes = {k: v for k, v in indexes.items() if k not in exclude_fields} + + if not singleton: + lower_level_model = create_model(__model_name=lower_class, **indexes, **columns) + top_level_model = create_model( + __model_name=top_class, + __base__=SchemaStructuredOutputResponseModel, + items=(List[lower_level_model], Field(default_factory=list)), + ) + + else: + top_level_model = create_model( + __model_name=top_class, __base__=SchemaStructuredOutputResponseModel, **indexes, **columns + ) + + top_level_model.__doc__ = clean_docstring(schema.description) + + return top_level_model + + +def drop_type_def_from_schema_json(schema_json: Dict[str, Any]) -> Dict[str, Any]: + for key in schema_json.get("properties", []): + 
schema_json["properties"][key].pop("type", None) + schema_json["properties"][key].pop("anyOf", None) + for definition in schema_json.get("definitions", {}).values(): + for key in definition.get("properties", {}): + definition["properties"][key].pop("type", None) + definition["properties"][key].pop("anyOf", None) + + return schema_json diff --git a/argilla-v1/src/extralit_v1/extraction/staging.py b/argilla-v1/src/extralit_v1/extraction/staging.py new file mode 100644 index 000000000..4ab076ff6 --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/staging.py @@ -0,0 +1,77 @@ +import json +import logging +from typing import Dict, Any, List + +import pandas as pd +from json_repair import repair_json + + +def list_to_str(x): + if isinstance(x, (list, tuple)): + if len(x) >= 1: + return ", ".join(x) + elif len(x) == 0: + return None + return x + + +def to_df(model, *args, **kwargs) -> pd.DataFrame: + items = model.dict(exclude_none=True)["items"] + dtypes = generate_dtypes(model, subset=items[0].keys()) if len(items) > 0 else {} + df = pd.DataFrame(items).astype(dtypes, errors="ignore") + + # Convert any list values to string delimited by comma + for col in df.select_dtypes(include=["object"]).columns: + df[col] = df[col].map(list_to_str) + + zero_value_columns = [col for col in df.columns if df[col].eq(0).all() or df[col].isna().all()] + df = df.drop(columns=zero_value_columns) + + return df + + +def generate_dtypes(model, subset: List[str] = None) -> Dict[str, Any]: + type_mapping = { + int: "Int64", + float: "Float32", + str: "string", + bool: "boolean", + } + + dtypes = {} + if hasattr(model, "items") and isinstance(model.items, list) and len(model.items) > 0: + item_model = model.items[0] + else: + item_model = model + + if not hasattr(item_model, "__fields__"): + return {} + + for name, field in item_model.__fields__.items(): + if subset and name not in subset: + continue + + python_type = field.outer_type_ + pandas_dtype = type_mapping.get(python_type) + if 
pandas_dtype is not None: + dtypes[name] = pandas_dtype + + return dtypes + + +def heal_json(json_string: str, return_on_failure="{}") -> str: + try: + # Try to parse the JSON string + json.loads(json_string) + return json_string + + except json.JSONDecodeError: + logging.warning(f"Attempting to fix broken JSON: ...{json_string[-100:]}".replace("\n", "")) + try: + healed_json_string = repair_json(json_string) + # Check if the healed JSON string is valid + json.loads(healed_json_string) + return healed_json_string + except Exception as e: + logging.info(f"Failed to repair JSON: {e}. Returning '{return_on_failure}'.") + return return_on_failure diff --git a/argilla-v1/src/extralit_v1/extraction/storage.py b/argilla-v1/src/extralit_v1/extraction/storage.py new file mode 100644 index 000000000..4329fe976 --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/storage.py @@ -0,0 +1,45 @@ +import os +from typing import Optional + +from llama_index.core import StorageContext +from extralit_v1.extraction.vector_store import WeaviateVectorStore, create_default_schema +from llama_index.vector_stores.weaviate.utils import class_schema_exists, NODE_SCHEMA, validate_client +from weaviate import Client, WeaviateClient + + +def get_storage_context( + weaviate_client: Optional[WeaviateClient] = None, + persist_dir: Optional[str] = None, + index_name: Optional[str] = None, +) -> StorageContext: + """ + Create a StorageContext given a persist directory. + + Args: + persist_dir: str + The directory where the index is persisted. + + Returns: + StorageContext + The created StorageContext. 
+ """ + kwargs = {} + if weaviate_client: + assert index_name + validate_client(weaviate_client) + schema_exists = class_schema_exists(client=weaviate_client, class_name=index_name) + if not schema_exists: + create_default_schema(client=weaviate_client, class_name=index_name) + + vector_store = WeaviateVectorStore(weaviate_client=weaviate_client, index_name=index_name, text_key="text") + kwargs["vector_store"] = vector_store + + elif persist_dir: + assert os.path.exists(persist_dir) + kwargs["persist_dir"] = persist_dir + + else: + raise ValueError("Either weaviate_client or persist_dir must be given") + + storage_context = StorageContext.from_defaults(**kwargs) + return storage_context diff --git a/argilla-v1/src/extralit_v1/extraction/utils.py b/argilla-v1/src/extralit_v1/extraction/utils.py new file mode 100644 index 000000000..7411b291f --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/utils.py @@ -0,0 +1,61 @@ +import logging +from typing import Union, List, Dict + +import pandas as pd +import pandera as pa +from llama_index.core import Response + +_LOGGER = logging.getLogger(__name__) + + +def convert_response_to_dataframe(response: Response) -> pd.DataFrame: + try: + df: pd.DataFrame = response.response.to_df() + except AttributeError as ae: + _LOGGER.error( + f"""Failed to convert response to DataFrame: {ae} + Response: {response.response} + Source nodes: {len(response.source_nodes)} + """ + ) + df = pd.DataFrame() + return df + + +def generate_reference_columns(df: pd.DataFrame, schema: pa.DataFrameSchema): + if schema.index is None: + return df + + index_names = [index.name.lower() for index in schema.index.indexes] if hasattr(schema.index, "indexes") else [] + for index_name in index_names: + if index_name not in df.columns: + df[index_name] = "NOTMATCHED" + if index_names: + df = df.set_index(index_names, verify_integrity=False) + return df + + +def filter_unique_columns(df: pd.DataFrame) -> pd.DataFrame: + """ + Drop columns that have the same 
value in all rows. + """ + if len(df) > 1: + return df.dropna(axis="columns", how="all").loc[:, (df.astype(str).nunique() > 1)] + else: + return df + + +def stringify_lists(obj: Union[List, Dict], conjunction="or") -> str: + if isinstance(obj, dict): + items = list(obj) + elif isinstance(obj, list): + items = obj + else: + return obj.__repr__() + + if len(items) > 2: + repr_str = ", ".join(str(item) for item in items[:-1]) + f", {conjunction} " + str(items[-1]) + else: + repr_str = ", ".join(str(item) for item in items) + + return repr_str diff --git a/argilla-v1/src/extralit_v1/extraction/vector_index.py b/argilla-v1/src/extralit_v1/extraction/vector_index.py new file mode 100644 index 000000000..48b710a46 --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/vector_index.py @@ -0,0 +1,237 @@ +import logging +import os.path +from collections import Counter +from os.path import join +from typing import Optional, Literal +import warnings + +import argilla_v1 as rg +from extralit_v1.storage.files import StorageType +import pandas as pd +from llama_index.core import VectorStoreIndex, load_index_from_storage, global_handler +from llama_index.core.node_parser import SentenceSplitter, JSONNodeParser +from llama_index.core.service_context import ServiceContext +from llama_index.core.storage import StorageContext +from llama_index.core.vector_stores import SimpleVectorStore, MetadataFilters, MetadataFilter, FilterOperator +from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbedding +from llama_index.llms.openai import OpenAI +from weaviate import WeaviateClient + +from extralit_v1.extraction.chunking import create_nodes +from extralit_v1.extraction.query import vectordb_contains_any +from extralit_v1.extraction.storage import get_storage_context +from extralit_v1.extraction.vector_store import WeaviateVectorStore + +DEFAULT_RETRIEVAL_MODE = OpenAIEmbeddingMode.TEXT_SEARCH_MODE +_LOGGER = logging.getLogger(__name__) +warnings.filterwarnings("ignore", 
category=DeprecationWarning) + + +def create_local_index( + paper: pd.Series, + preprocessing_path="data/preprocessing/nougat/", + preprocessing_dataset: rg.FeedbackDataset = None, + persist_dir: Optional[str] = None, + embed_model="text-embedding-3-small", + dimensions=1536, + retrieval_mode=DEFAULT_RETRIEVAL_MODE, + chunk_size=4096, + chunk_overlap=200, + verbose=True, +) -> VectorStoreIndex: + text_nodes, table_nodes = create_nodes( + paper, + preprocessing_path=preprocessing_path, + preprocessing_dataset=preprocessing_dataset, + storage_type=StorageType.FILE, + ) + + _LOGGER.info( + f"Creating index with {len(text_nodes)} text and {len(table_nodes)} table segments, `persist_dir={persist_dir}`" + ) + + storage_context = get_storage_context(persist_dir=persist_dir) + embedding_model = OpenAIEmbedding( + mode=retrieval_mode, + model=embed_model, + dimensions=dimensions, + ) + + if global_handler and hasattr(global_handler, "set_trace_params"): + global_handler.set_trace_params(name=f"embed-{paper.name}", tags=[paper.name]) + + embed_model_context = ServiceContext.from_defaults( + embed_model=embedding_model, + node_parser=SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap), + ) + index = VectorStoreIndex.from_documents( + text_nodes, storage_context=storage_context, service_context=embed_model_context + ) + + index.insert_nodes(table_nodes, node_parser=JSONNodeParser(chunk_size=chunk_size, chunk_overlap=chunk_overlap)) + + if persist_dir and not storage_context.vector_store: + assert os.path.exists(persist_dir) + index.storage_context.persist(persist_dir) + + if verbose: + nodes_counts = Counter([doc.metadata["header"] for doc in index.docstore.docs.values()]) + nodes_counts = [(header, count) for header, count in nodes_counts.most_common() if count > 1] + print(pd.DataFrame(nodes_counts, columns=["header", "n_chunks"])) if nodes_counts else None + + return index + + +def create_vector_index( + paper: pd.Series, + weaviate_client: WeaviateClient, 
+ preprocessing_dataset: Optional[rg.FeedbackDataset] = None, + preprocessing_path="data/preprocessing/nougat/", + index_name: Optional[str] = "LlamaIndexDocumentSections", + embed_model="text-embedding-3-small", + dimensions=1536, + retrieval_mode=DEFAULT_RETRIEVAL_MODE, + overwrite: Literal[True, "text", "table", "figure"] = "table", + chunk_size=4096, + chunk_overlap=200, + storage_type: StorageType = StorageType.FILE, + bucket_name: Optional[str] = None, + verbose=True, +) -> VectorStoreIndex: + """ + Creates a VectorStoreIndex for a given paper and loads it into a vector db. + + Args: + paper (pd.Series): The paper to be indexed. + weaviate_client (WeaviateClient): The Weaviate client to use. + preprocessing_dataset (Optional[rg.FeedbackDataset]): + The preprocessing dataset to use. Defaults to None. + If given, the TableSegments will be loaded from the Argilla dataset with users' annotations. If None, + the TableSegments will be loaded from the preprocessed table extractions locally from `preprocessing_path`. + preprocessing_path (str): The path to the preprocessing data. Defaults to 'data/preprocessing/nougat/'. + index_name (Optional[str]): The name of the index. Defaults to "LlamaIndexDocumentSections". + embed_model (str): The model to use for embedding documents. Defaults to 'text-embedding-3-small'. + dimensions (int): The dimensions of the embedding model. Defaults to 1536. + retrieval_mode (str): The retrieval mode of the embedding model. Defaults to DEFAULT_RETRIEVAL_MODE. + overwrite (Literal[True, 'text', 'table', 'figure']): The type of nodes to overwrite. Defaults to True, + which overwrites all nodes for the reference. + chunk_size (int): The size of the chunks to split the text into. Defaults to 4096. + chunk_overlap (int): The size of the overlap between chunks. Defaults to 200. + storage_type (StorageType): The storage type to use. Defaults to StorageType.FILE. + bucket_name (Optional[str]): The name of the S3 bucket (i.e. 
workspace name) to use. Defaults to None. + verbose (bool): Whether to print verbose output. Defaults to True. + + Returns: + VectorStoreIndex: The loaded VectorStoreIndex. + """ + + text_nodes, table_nodes = create_nodes( + paper, + preprocessing_path=preprocessing_path, + preprocessing_dataset=preprocessing_dataset, + storage_type=storage_type, + bucket_name=bucket_name, + ) + + if global_handler and hasattr(global_handler, "set_trace_params"): + global_handler.set_trace_params(name=f"embed-{paper.name}", tags=[paper.name]) + + vector_store = WeaviateVectorStore(weaviate_client=weaviate_client, index_name=index_name) + has_existing_node = vectordb_contains_any(paper.name, weaviate_client=weaviate_client, index_name=index_name) + if has_existing_node and overwrite: + delete_filters = [MetadataFilter(key="reference", value=paper.name, operator=FilterOperator.EQ)] + if isinstance(overwrite, str): + delete_filters.append(MetadataFilter(key="type", value=overwrite, operator=FilterOperator.EQ)) + vector_store.delete_nodes( + filters=MetadataFilters( + filters=delete_filters, + ) + ) + + embedding_model = OpenAIEmbedding(mode=retrieval_mode, model=embed_model, dimensions=dimensions) + embed_model_context = ServiceContext.from_defaults( + embed_model=embedding_model, + node_parser=SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap), + ) + + if has_existing_node and not overwrite: + _LOGGER.info(f"Skipping existing index for {paper.name}") + return VectorStoreIndex.from_vector_store(vector_store, service_context=embed_model_context) + + loaded_index = VectorStoreIndex.from_vector_store(vector_store, service_context=embed_model_context) + for node in text_nodes: + if has_existing_node and overwrite != "text": + continue + loaded_index.insert(node) + for node in table_nodes: + if has_existing_node and overwrite != "table": + continue + loaded_index.insert(node, node_parser=JSONNodeParser(chunk_size=chunk_size, chunk_overlap=chunk_overlap)) + + if verbose: + 
nodes_counts = Counter([doc.metadata["header"] for doc in loaded_index.docstore.docs.values()]) + nodes_counts = [(header, count) for header, count in nodes_counts.most_common() if count > 1] + print(pd.DataFrame(nodes_counts, columns=["header", "n_chunks"])) if nodes_counts else None + + return loaded_index + + +def load_index( + paper: pd.Series, + llm_model="gpt-4o", + embed_model="text-embedding-3-small", + weaviate_client: Optional[WeaviateClient] = None, + index_name: Optional[str] = "LlamaIndexDocumentSections", + persist_dir="data/interim/vectorstore/", + **kwargs, +) -> VectorStoreIndex: + """ + Creates or loads a VectorStoreIndex for a given paper. + + This function will either create a new VectorStoreIndex by processing the given paper, or load an existing one from + the specified directory. If the `reindex` parameter is set to True, the function will reindex the paper even if an + existing index is found. + + Args: + paper (pd.Series): The paper to be indexed. + llm_model (str, optional): The model to use for extraction. Defaults to 'gpt-3.5-turbo'. + embed_model (str, optional): The model to use for embedding documents. Defaults to 'text-embedding-3-small'. + weaviate_client (Client, optional): The Weaviate client to use. Defaults to None. + persist_dir (str, optional): The directory where the index is persisted. Defaults to 'data/interim/vectorstore/'. + + Returns: + VectorStoreIndex: The created or loaded VectorStoreIndex. 
+ """ + # Load the existing index + storage_context = get_storage_context( + weaviate_client=weaviate_client, index_name=index_name, persist_dir=join(persist_dir, paper.name, embed_model) + ) + llm = OpenAI(model=llm_model, temperature=0.0, max_retries=3, streaming=True) + service_context = ServiceContext.from_defaults(llm=llm) + + if not isinstance(storage_context.vector_store, SimpleVectorStore): + index = VectorStoreIndex.from_vector_store(storage_context.vector_store, service_context=service_context) + else: + index = load_index_from_storage(storage_context, service_context=service_context) + + return index + + +def load_index_retriever( + paper: pd.Series, + similarity_top_k=3, + embed_model="text-embedding-3-small", + vectorstore_path="data/interim/vectorstore/", + retrieval_mode=DEFAULT_RETRIEVAL_MODE, + **kwargs, +): + persist_dir = join(vectorstore_path, paper.name, embed_model) + storage_context = StorageContext.from_defaults(persist_dir=persist_dir) + + llm = OpenAIEmbedding(model=embed_model, mode=retrieval_mode) + service_context = ServiceContext.from_defaults(embed_model=llm) + + index = load_index_from_storage(storage_context, service_context=service_context) + retriever = index.as_retriever(similarity_top_k=similarity_top_k, **kwargs) + + return retriever diff --git a/argilla-v1/src/extralit_v1/extraction/vector_store.py b/argilla-v1/src/extralit_v1/extraction/vector_store.py new file mode 100644 index 000000000..e62bd766c --- /dev/null +++ b/argilla-v1/src/extralit_v1/extraction/vector_store.py @@ -0,0 +1,252 @@ +"""Weaviate Vector store index. + +An index that is built on top of an existing vector store. 
+ +""" + +import logging +from typing import Any, List, Dict, Optional, Union, Callable + +import weaviate # noqa +import weaviate.classes as wvc +from llama_index.core.schema import BaseNode +from llama_index.core.vector_stores.types import ( + MetadataFilters, + VectorStoreQuery, + VectorStoreQueryMode, + VectorStoreQueryResult, + FilterOperator, +) +from llama_index.vector_stores.weaviate import WeaviateVectorStore as WeaviateVectorStoreV0_10_0 +from llama_index.vector_stores.weaviate.utils import ( + get_all_properties, + get_node_similarity, + to_node, + validate_client, +) + +_LOGGER = logging.getLogger(__name__) + +NODE_SCHEMA: List[Dict] = [ + { + "dataType": ["text"], + "description": "Text property", + "name": "text", + }, + { + "dataType": ["text"], + "description": "The ref_doc_id of the Node", + "name": "ref_doc_id", + }, + { + "dataType": ["text"], + "description": "node_info (in JSON)", + "name": "node_info", + }, + { + "dataType": ["text"], + "description": "The relationships of the node (in JSON)", + "name": "relationships", + }, + { + "dataType": ["text"], + "description": "The reference of the Node", + "name": "reference", + }, + { + "dataType": ["text"], + "description": "The type of the Node", + "name": "type", + }, + { + "dataType": ["text"], + "description": "The doc_id of the Node", + "name": "doc_id", + }, +] + + +def create_default_schema(client: Any, class_name: str) -> None: + """Create default schema.""" + validate_client(client) + class_schema = { + "class": class_name, + "description": f"Class for {class_name}", + "properties": NODE_SCHEMA, + "vectorIndexType": "flat", + # "multiTenancyConfig": {"enabled": True}, + } + client.collections.create_from_dict(class_schema) + + +def _transform_weaviate_filter_condition(condition: str) -> Callable: + """Translate standard metadata filter op to Chroma specific spec.""" + if condition == "and": + return wvc.query.Filter.all_of + elif condition == "or": + return wvc.query.Filter.any_of + else: 
+ raise ValueError(f"Filter condition {condition} not supported") + + +def _transform_weaviate_filter_operator(operator: FilterOperator) -> str: + """Translate standard metadata filter operator to Weaviate specific spec. + See https://weaviate.io/developers/weaviate/api/graphql/filters#filter-structure + """ + if operator == FilterOperator.NE: + return "not_equal" + elif operator == FilterOperator.EQ: + return "equal" + elif operator == FilterOperator.GT: + return "greater_than" + elif operator == FilterOperator.LT: + return "less_than" + elif operator == FilterOperator.GTE: + return "greater_or_equal" + elif operator == FilterOperator.LTE: + return "less_or_equal" + elif operator == FilterOperator.IN: + return "contains_any" + elif operator == FilterOperator.ALL: + return "contains_all" + elif operator == FilterOperator.TEXT_MATCH: + return "like" + else: + raise ValueError(f"Filter operator {operator} not supported") + + +def _to_weaviate_filter( + standard_filters: MetadataFilters, +) -> Union[wvc.query.Filter, List[wvc.query.Filter]]: + filters_list = [] + condition = standard_filters.condition or "and" + condition = _transform_weaviate_filter_condition(condition) + + if standard_filters.filters: + for filter in standard_filters.filters: + filters_list.append( + getattr( + wvc.query.Filter.by_property(filter.key), + _transform_weaviate_filter_operator(filter.operator), + )(filter.value) + ) + else: + return {} + + if len(filters_list) == 1: + # If there is only one filter, return it directly + return filters_list[0] + + return condition(filters_list) + + +class WeaviateVectorStore(WeaviateVectorStoreV0_10_0): + def get_nodes( + self, node_ids: Optional[List[str]] = None, filters: Optional[MetadataFilters] = None + ) -> List[BaseNode]: + collection = self._client.collections.get(self.index_name) + all_properties = get_all_properties(self._client, self.index_name) + + if filters is not None: + filters = _to_weaviate_filter(filters) + + # list of documents to 
constrain search + if node_ids is not None: + filters = wvc.query.Filter.by_property("id").contains_any(node_ids) + + query_result = collection.query.fetch_objects( + filters=filters, + return_properties=all_properties, + include_vector=False, + ) + + entries = [to_node(o.__dict__) for o in query_result.objects] + return entries + + def delete_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + **delete_kwargs: Any, + ) -> None: + collection = self._client.collections.get(self.index_name) + + if node_ids is not None: + filters = wvc.query.Filter.by_property("id").contains_any(node_ids) + elif filters is not None: + filters = _to_weaviate_filter(filters) + else: + raise ValueError("Either node_ids or filters must be provided") + + results = collection.data.delete_many(where=filters, verbose=True) + if results.objects: + _LOGGER.debug(f"Deleted {len(results.objects)} nodes") + + def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: + """Query index for top k most similar nodes.""" + all_properties = get_all_properties(self._client, self.index_name) + collection = self._client.collections.get(self.index_name) + filters = None + + # list of documents to constrain search + if query.doc_ids: + filters = wvc.query.Filter.by_property("doc_id").contains_any(query.doc_ids) + + if query.node_ids: + filters = wvc.query.Filter.by_property("id").contains_any(query.node_ids) + + return_metatada = wvc.query.MetadataQuery(distance=True, score=True) + + vector = query.query_embedding + similarity_key = "distance" + if query.mode == VectorStoreQueryMode.DEFAULT: + _LOGGER.debug("Using vector search") + if vector is not None: + alpha = 1 + elif query.mode == VectorStoreQueryMode.HYBRID: + _LOGGER.debug(f"Using hybrid search with alpha {query.alpha}") + similarity_key = "score" + if vector is not None and query.query_str: + alpha = query.alpha + + if query.filters is not None: + filters = 
_to_weaviate_filter(query.filters) + elif "filter" in kwargs and kwargs["filter"] is not None: + filters = kwargs["filter"] + + limit = query.similarity_top_k + _LOGGER.debug(f"Using limit of {query.similarity_top_k}") + + # execute query + try: + query_result = collection.query.hybrid( + query=query.query_str, + vector=vector, + alpha=alpha, + limit=limit, + filters=filters, + return_metadata=return_metatada, + return_properties=all_properties, + include_vector=True, + ) + except weaviate.exceptions.WeaviateQueryError as e: + raise ValueError(f"Invalid query, got errors: {e.message}") + + # parse results + + entries = query_result.objects + + similarities = [] + nodes: List[BaseNode] = [] + node_ids = [] + + for i, entry in enumerate(entries): + if i < query.similarity_top_k: + entry_as_dict = entry.__dict__ + similarities.append(get_node_similarity(entry_as_dict, similarity_key)) + nodes.append(to_node(entry_as_dict, text_key=self.text_key)) + node_ids.append(nodes[-1].node_id) + else: + break + + return VectorStoreQueryResult(nodes=nodes, ids=node_ids, similarities=similarities) diff --git a/argilla/src/extralit/metrics/__init__.py b/argilla-v1/src/extralit_v1/metrics/__init__.py similarity index 100% rename from argilla/src/extralit/metrics/__init__.py rename to argilla-v1/src/extralit_v1/metrics/__init__.py diff --git a/argilla-v1/src/extralit_v1/metrics/extraction.py b/argilla-v1/src/extralit_v1/metrics/extraction.py new file mode 100644 index 000000000..a73ef4a39 --- /dev/null +++ b/argilla-v1/src/extralit_v1/metrics/extraction.py @@ -0,0 +1,223 @@ +import io +import logging +from collections import defaultdict +from typing import Optional, List, Dict, Union, Literal + +import pandas as pd + +from extralit_v1.extraction.models.paper import PaperExtraction +from extralit_v1.metrics.grits import grits_from_html +from extralit_v1.metrics.utils import harmonize_columns, reorder_rows, convert_metrics_to_df + + +def grits_from_batch( + true_batch: Dict[str, 
PaperExtraction], + pred_batch: Dict[str, PaperExtraction], + exclude_columns: List[str] = [], + pairwise=False, + index_names: List[str] = ["Reference", "Schema", "Field"], + compute_mean: Optional[str] = None, + **kwargs, +) -> pd.DataFrame: + metrics = defaultdict(lambda: {}) + + if pairwise: + for true_key in true_batch: + for pred_key in pred_batch: + true_extractions = true_batch[true_key] + pred_extractions = pred_batch[pred_key] + if not pred_extractions or not true_extractions: + continue + + outputs = grits_paper(true_extractions, pred_extractions, exclude_columns=exclude_columns, **kwargs) + if outputs: + metrics[(true_key, pred_key)] = pd.concat({str(k): v for k, v in outputs.items()}, axis=1).T + else: + for batch in set(true_batch) & set(pred_batch): + true_extractions = true_batch[batch] + pred_extractions = pred_batch[batch] + if not pred_extractions or not true_extractions: + continue + + if isinstance(true_extractions, PaperExtraction): + outputs = grits_paper(true_extractions, pred_extractions, exclude_columns=exclude_columns, **kwargs) + metrics[batch] = pd.concat({str(k): v for k, v in outputs.items()}, axis=1).T + elif isinstance(true_extractions, list) and isinstance(true_extractions[0], pd.DataFrame): + outputs = grits_multi_tables(true_extractions, pred_extractions, **kwargs) + metrics[batch] = outputs + else: + raise ValueError(f"Invalid type for true_extractions: {type(true_extractions)}") + + metrics_df = pd.concat(metrics, axis=0) + metrics_df.index.names = index_names[: metrics_df.index.nlevels] + + if compute_mean: + aggregated = metrics_df.groupby(compute_mean).mean() + return aggregated + + return metrics_df + + +def grits_paper( + true_extractions: PaperExtraction, + pred_extractions: PaperExtraction, + exclude_columns=["Site"], + verbose=False, + metrics: List[Literal["top", "con", "upper_bound", "alignment"]] = ["con"], + **kwargs, +) -> Dict[str, Union[pd.Series, Dict[str, float]]]: + exclude_columns = exclude_columns or [] + + 
output = {} + + for schema_name in true_extractions.schemas.ordering: + if schema_name not in true_extractions: + output[schema_name] = None + continue + + schema_deps = true_extractions.schemas.upstream_dependencies[schema_name] + if any(schema_name not in pred_extractions for schema_name in schema_deps): + logging.warning(f"Missing extractions for {schema_deps} in {pred_extractions.__repr__()}") + + ref_columns = [ + col + for dep in schema_deps + for col in true_extractions.schemas.index_names(dep) + true_extractions.schemas.columns(dep) + ] + + output[schema_name] = grits_from_pandas( + true_extractions.get_joined_data(schema_name), + pred_extractions.get_joined_data(schema_name) if schema_name in pred_extractions else pd.DataFrame(), + index_columns=ref_columns, + only_columns=true_extractions.schemas.columns(schema_name), + exclude_columns=exclude_columns + ref_columns, + metrics=metrics, + format="series", + verbose=verbose, + **kwargs, + ) + + return output + + +def grits_from_pandas( + true_df: pd.DataFrame, + pred_df: pd.DataFrame, + index_columns: Optional[List[str]] = None, + only_columns: Optional[List[str]] = None, + exclude_columns: Optional[List[str]] = None, + metrics: List[Literal["top", "con", "upper_bound", "alignment"]] = ["top", "con"], + reduce: Literal["table", "column"] = "table", + format: Literal["series", "dataframe", "None"] = None, + nan_value="NA", + verbose=False, +) -> pd.Series: + """ + Grid Table Similarity (GriTS) evaluation metric for data extraction task. + + Args: + true_df (pd.DataFrame): Ground truth table. + pred_df (pd.DataFrame): Predicted table. + index_columns (list): List of columns to sort both true_df and pred_df to the same row ordering. + only_columns (list): List of columns to include in evaluation. + exclude_columns (list): List of columns to exclude from evaluation. + metrics (list): List of metrics to compute. Subset of {'top', 'con', 'upper_bound'}. + reduce (str): One of {'table', 'column'}. 
+ format (str): Output format. One of {None, 'series', 'dataframe'}. + verbose (bool): For debugging purposes, return preprocessed dataframes before + they're passed into `grits_from_html`. + """ + true_df = true_df.copy().dropna(axis="columns", how="all").dropna(axis=0, how="all") + pred_df = pred_df.copy().dropna(axis="columns", how="all").dropna(axis=0, how="all") + + if isinstance(exclude_columns, pd.Index): + exclude_columns = exclude_columns.tolist() + exclude_columns = exclude_columns or [] + + true_df, pred_df = harmonize_columns(true_df, pred_df) + true_df, pred_df = reorder_rows(true_df, pred_df, index_columns=index_columns, verbose=verbose) + + # Drop columns not applicable for the schema + if exclude_columns is not None: + true_df = true_df.drop(columns=exclude_columns, errors="ignore") + pred_df = pred_df.drop(columns=exclude_columns, errors="ignore") + + if only_columns is not None: + true_df = true_df.filter(only_columns, axis="columns") + pred_df = pred_df.filter(only_columns, axis="columns") + + if nan_value: + true_df = true_df.loc[:, (true_df != nan_value).any(axis=0)] + pred_df = pred_df.loc[:, (pred_df != nan_value).any(axis=0)] + + to_html_args = dict(index=False, na_rep="", float_format=lambda x: "%.0f" % x if x == round(x) else "%.2f" % x) + + if verbose: + print("\ngrits_from_pandas debug:\n", "only_columns", only_columns or true_df.columns.tolist()) + + if reduce == "table": + true_html = true_df.to_html(**to_html_args) + pred_html = pred_df.to_html(**to_html_args) + + outputs = grits_from_html(true_html, pred_html, metrics=metrics) + outputs = convert_metrics_to_df(outputs, format) + + elif reduce == "column": + outputs = {} + for col in true_df.columns.difference(exclude_columns or []): + true_html = true_df[[col]].to_html(**to_html_args) + pred_html = pred_df[[col]].to_html(**to_html_args) + outputs[col] = convert_metrics_to_df( + grits_from_html(true_html, pred_html, metrics=metrics), format="series" + ) + + outputs = 
pd.concat(outputs, axis=1) + else: + raise ValueError(f"Invalid value for reduce: {reduce}") + + if verbose >= 2: + outputs["true_df"], outputs["pred_df"] = true_df, pred_df + + return outputs + + +def grits_multi_tables( + true_tables: List[Union[pd.DataFrame, str]], + pred_tables: List[Union[pd.DataFrame, str]], + only_common_columns=True, + **kwargs, +) -> pd.DataFrame: + results: Dict[str, pd.Series] = defaultdict(dict) + + for i, (true_df, pred_df) in enumerate(zip(true_tables, pred_tables)): + try: + if isinstance(true_df, pd.DataFrame): + if only_common_columns: + only_columns = ( + true_df.columns.intersection(pred_df.columns) + .difference(kwargs.get("index_columns", [])) + .tolist() + ) + kwargs.pop("only_columns", None) + else: + only_columns = kwargs.pop("only_columns", None) + + results[i] = grits_from_pandas(true_df, pred_df, format="series", only_columns=only_columns, **kwargs) + + elif isinstance(true_df, str): + metrics = grits_from_html( + pd.read_html(io.StringIO(true_df))[0].to_html(index=False, na_rep=""), + pd.read_html(io.StringIO(pred_df))[0].to_html(index=False, na_rep=""), + **kwargs, + ) + results[i] = convert_metrics_to_df(metrics, format="series") + + except Exception as e: + logging.error(f"Failed to compute metrics for index {i}. \n{e}") + results[i] = None + + if not {str(k): v for k, v in results.items() if v is not None}: + return pd.DataFrame() + + metrics_df = pd.concat(results, axis=1).T + return metrics_df diff --git a/argilla-v1/src/extralit_v1/metrics/grits.py b/argilla-v1/src/extralit_v1/metrics/grits.py new file mode 100644 index 000000000..200b98553 --- /dev/null +++ b/argilla-v1/src/extralit_v1/metrics/grits.py @@ -0,0 +1,575 @@ +""" +MIT License +Copyright (c) Microsoft Corporation. 
+""" + +import itertools +import logging +import xml.etree.ElementTree as ET +from collections import defaultdict +from difflib import SequenceMatcher +from typing import Dict, List, Tuple, Callable + +import numpy as np +import pandas as pd +from bs4 import BeautifulSoup +from fitz import Rect + + +def compute_fscore(num_true_positives, num_true, num_positives) -> Tuple[float, float, float]: + """ + Compute the f-score or f-measure for a collection of predictions. + + Conventions: + - precision is 1 when there are no predicted instances + - recall is 1 when there are no true instances + - fscore is 0 when recall or precision is 0 + """ + if num_positives > 0: + precision = num_true_positives / num_positives + else: + precision = 1 + if num_true > 0: + recall = num_true_positives / num_true + else: + recall = 1 + + if precision + recall > 0: + fscore = 2 * precision * recall / (precision + recall) + else: + fscore = 0 + + return fscore, precision, recall + + +def initialize_DP(sequence1_length, sequence2_length): + """ + Helper function to initialize dynamic programming data structures. + """ + # Initialize DP tables + scores = np.zeros((sequence1_length + 1, sequence2_length + 1)) + pointers = np.zeros((sequence1_length + 1, sequence2_length + 1)) + + # Initialize pointers in DP table + for seq1_idx in range(1, sequence1_length + 1): + pointers[seq1_idx, 0] = -1 + + # Initialize pointers in DP table + for seq2_idx in range(1, sequence2_length + 1): + pointers[0, seq2_idx] = 1 + + return scores, pointers + + +def traceback(pointers): + """ + Dynamic programming traceback to determine the aligned indices + between the two sequences. 
+ + Traceback convention: -1 = up, 1 = left, 0 = diag up-left + """ + seq1_idx = pointers.shape[0] - 1 + seq2_idx = pointers.shape[1] - 1 + aligned_sequence1_indices = [] + aligned_sequence2_indices = [] + while not (seq1_idx == 0 and seq2_idx == 0): + if pointers[seq1_idx, seq2_idx] == -1: + seq1_idx -= 1 + elif pointers[seq1_idx, seq2_idx] == 1: + seq2_idx -= 1 + else: + seq1_idx -= 1 + seq2_idx -= 1 + aligned_sequence1_indices.append(seq1_idx) + aligned_sequence2_indices.append(seq2_idx) + + aligned_sequence1_indices = aligned_sequence1_indices[::-1] + aligned_sequence2_indices = aligned_sequence2_indices[::-1] + + return aligned_sequence1_indices, aligned_sequence2_indices + + +def align_1d(sequence1, sequence2, reward_lookup, return_alignment=False): + """ + Dynamic programming alignment between two sequences, + with memoized rewards. + + Sequences are represented as indices into the rewards lookup table. + + Traceback convention: -1 = up, 1 = left, 0 = diag up-left + """ + sequence1_length = len(sequence1) + sequence2_length = len(sequence2) + + scores, pointers = initialize_DP(sequence1_length, sequence2_length) + + for seq1_idx in range(1, sequence1_length + 1): + for seq2_idx in range(1, sequence2_length + 1): + reward = reward_lookup[sequence1[seq1_idx - 1] + sequence2[seq2_idx - 1]] + diag_score = scores[seq1_idx - 1, seq2_idx - 1] + reward + skip_seq2_score = scores[seq1_idx, seq2_idx - 1] + skip_seq1_score = scores[seq1_idx - 1, seq2_idx] + + max_score = max(diag_score, skip_seq1_score, skip_seq2_score) + scores[seq1_idx, seq2_idx] = max_score + if diag_score == max_score: + pointers[seq1_idx, seq2_idx] = 0 + elif skip_seq1_score == max_score: + pointers[seq1_idx, seq2_idx] = -1 + else: # skip_seq2_score == max_score + pointers[seq1_idx, seq2_idx] = 1 + + score = scores[-1, -1] + + if not return_alignment: + return score + + # Traceback + sequence1_indices, sequence2_indices = traceback(pointers) + + return sequence1_indices, sequence2_indices, score + 
+ +def align_2d_outer(true_shape, pred_shape, reward_lookup): + """ + Dynamic programming matrix alignment posed as 2D + sequence-of-sequences alignment: + Align two outer sequences whose entries are also sequences, + where the match reward between the inner sequence entries + is their 1D sequence alignment score. + + Traceback convention: -1 = up, 1 = left, 0 = diag up-left + """ + + scores, pointers = initialize_DP(true_shape[0], pred_shape[0]) + + for row_idx in range(1, true_shape[0] + 1): + for col_idx in range(1, pred_shape[0] + 1): + reward = align_1d( + [(row_idx - 1, tcol) for tcol in range(true_shape[1])], + [(col_idx - 1, prow) for prow in range(pred_shape[1])], + reward_lookup, + ) + diag_score = scores[row_idx - 1, col_idx - 1] + reward + same_row_score = scores[row_idx, col_idx - 1] + same_col_score = scores[row_idx - 1, col_idx] + + max_score = max(diag_score, same_col_score, same_row_score) + scores[row_idx, col_idx] = max_score + if diag_score == max_score: + pointers[row_idx, col_idx] = 0 + elif same_col_score == max_score: + pointers[row_idx, col_idx] = -1 + else: + pointers[row_idx, col_idx] = 1 + + score = scores[-1, -1] + + aligned_true_indices, aligned_pred_indices = traceback(pointers) + + return aligned_true_indices, aligned_pred_indices, score + + +def factored_2dmss( + true_cell_grid: np.ndarray, pred_cell_grid: np.ndarray, reward_function: Callable, return_substructures=False +) -> Tuple[float, float, float, float]: + """ + Factored 2D-MSS: Factored two-dimensional most-similar substructures + + This is a polynomial-time heuristic to computing the 2D-MSS of two matrices, + which is NP hard. + + A substructure of a matrix is a subset of its rows and its columns. + + The most similar substructures of two matrices, A and B, are the substructures + A' and B', where the sum of the similarity over all corresponding entries + A'(i, j) and B'(i, j) is greatest. 
+ """ + pre_computed_rewards = {} + transpose_rewards = {} + for trow, tcol, prow, pcol in itertools.product( + range(true_cell_grid.shape[0]), + range(true_cell_grid.shape[1]), + range(pred_cell_grid.shape[0]), + range(pred_cell_grid.shape[1]), + ): + reward = reward_function(true_cell_grid[trow, tcol], pred_cell_grid[prow, pcol]) + + pre_computed_rewards[(trow, tcol, prow, pcol)] = reward + transpose_rewards[(tcol, trow, pcol, prow)] = reward + + num_pos = pred_cell_grid.shape[0] * pred_cell_grid.shape[1] + num_true = true_cell_grid.shape[0] * true_cell_grid.shape[1] + + true_row_nums, pred_row_nums, row_pos_match_score = align_2d_outer( + true_cell_grid.shape[:2], pred_cell_grid.shape[:2], pre_computed_rewards + ) + + true_column_nums, pred_column_nums, col_pos_match_score = align_2d_outer( + true_cell_grid.shape[:2][::-1], pred_cell_grid.shape[:2][::-1], transpose_rewards + ) + + if return_substructures: + true_substructure = true_cell_grid[true_row_nums, :][:, true_column_nums] + pred_substructure = pred_cell_grid[pred_row_nums, :][:, pred_column_nums] + + return true_substructure, pred_substructure + + pos_match_score_upper_bound = min(row_pos_match_score, col_pos_match_score) + upper_bound_score, _, _ = compute_fscore(pos_match_score_upper_bound, num_pos, num_true) + + positive_match_score = 0 + for true_row_num, pred_row_num in zip(true_row_nums, pred_row_nums): + for true_column_num, pred_column_num in zip(true_column_nums, pred_column_nums): + positive_match_score += pre_computed_rewards[(true_row_num, true_column_num, pred_row_num, pred_column_num)] + + fscore, precision, recall = compute_fscore(positive_match_score, num_true, num_pos) + + return fscore, precision, recall, upper_bound_score + + +def lcs_string(string1: str, string2: str) -> str: + s = SequenceMatcher(None, string1, string2) + lcs = "".join([string1[block.a : (block.a + block.size)] for block in s.get_matching_blocks()]) + return lcs + + +def lcs_similarity(string1: str, string2: str) -> 
float: + if len(string1) == 0 and len(string2) == 0: + return 1 + lcs = lcs_string(string1, string2) + return 2 * len(lcs) / (len(string1) + len(string2)) + + +def iou(bbox1, bbox2): + """ + Compute the intersection-over-union of two bounding boxes. + """ + intersection = Rect(bbox1).intersect(bbox2) + union = Rect(bbox1).include_rect(bbox2) + + union_area = union.get_area() + if union_area > 0: + return intersection.get_area() / union.get_area() + + return 0 + + +def cells_to_grid(cells, key="bbox"): + """ + Convert from a list of cells to a matrix of grid cell features. + This matrix representation is the input to GriTS. + + For key, use: + - 'bbox' for computing GriTS_Loc + - 'cell_text' for computing GriTS_Con + """ + if len(cells) == 0: + return [[]] + num_rows = max([max(cell["row_nums"]) for cell in cells]) + 1 + num_columns = max([max(cell["column_nums"]) for cell in cells]) + 1 + cell_grid = np.zeros((num_rows, num_columns)).tolist() + for cell in cells: + for row_num in cell["row_nums"]: + for column_num in cell["column_nums"]: + cell_grid[row_num][column_num] = cell[key] + + return cell_grid + + +def cells_to_relspan_grid(cells: List[Dict[str, List[int]]]) -> List[List[List[int]]]: + """ + Convert from a list of cells to the matrix of grid cell features + used for computing GriTS_Top. 
+ """ + if len(cells) == 0: + return [[]] + num_rows = max([max(cell["row_nums"]) for cell in cells]) + 1 + num_columns = max([max(cell["column_nums"]) for cell in cells]) + 1 + cell_grid = np.zeros((num_rows, num_columns)).tolist() + for cell in cells: + min_row_num = min(cell["row_nums"]) + min_column_num = min(cell["column_nums"]) + max_row_num = max(cell["row_nums"]) + 1 + max_column_num = max(cell["column_nums"]) + 1 + for row_num in cell["row_nums"]: + for column_num in cell["column_nums"]: + cell_grid[row_num][column_num] = [ + min_column_num - column_num, + min_row_num - row_num, + max_column_num - column_num, + max_row_num - row_num, + ] + + return cell_grid + + +def get_spanning_cell_rows_and_columns(spanning_cells, rows, columns): + """ + Determine which grid cell locations (row-column) each spanning cell + corresponds to. + """ + matches_by_spanning_cell = [] + all_matches = set() + for spanning_cell in spanning_cells: + row_matches = set() + column_matches = set() + for row_num, row in enumerate(rows): + bbox1 = [spanning_cell["bbox"][0], row["bbox"][1], spanning_cell["bbox"][2], row["bbox"][3]] + bbox2 = Rect(spanning_cell["bbox"]).intersect(bbox1) + if bbox2.get_area() / Rect(bbox1).get_area() >= 0.5: + row_matches.add(row_num) + for column_num, column in enumerate(columns): + bbox1 = [column["bbox"][0], spanning_cell["bbox"][1], column["bbox"][2], spanning_cell["bbox"][3]] + bbox2 = Rect(spanning_cell["bbox"]).intersect(bbox1) + if bbox2.get_area() / Rect(bbox1).get_area() >= 0.5: + column_matches.add(column_num) + already_taken = False + this_matches = [] + for row_num in row_matches: + for column_num in column_matches: + this_matches.append((row_num, column_num)) + if (row_num, column_num) in all_matches: + already_taken = True + if not already_taken: + for match in this_matches: + all_matches.add(match) + matches_by_spanning_cell.append(this_matches) + row_nums = [elem[0] for elem in this_matches] + column_nums = [elem[1] for elem in 
this_matches] + row_rect = Rect() + for row_num in row_nums: + row_rect.include_rect(rows[row_num]["bbox"]) + column_rect = Rect() + for column_num in column_nums: + column_rect.include_rect(columns[column_num]["bbox"]) + spanning_cell["bbox"] = list(row_rect.intersect(column_rect)) + else: + matches_by_spanning_cell.append([]) + + return matches_by_spanning_cell + + +def output_to_dilatedbbox_grid(bboxes, labels, scores): + """ + Compute the matrix of grid cell features for GriTS_Loc but using the raw predicted + and ground truth bounding boxes, not the post-processed boxes. + + In the case of the itnrecal used in the PubTables-1M paper, these boxes are + *dilated*, which means they are larger than the actual ground truth boxes. + + Computing GriTS_Loc with dilated bounding boxes is probably not very useful + for itnrecal comparison but could be useful for understanding the behavior of + an individual itnrecal. + """ + rows = [{"bbox": bbox} for bbox, label in zip(bboxes, labels) if label == 2] + columns = [{"bbox": bbox} for bbox, label in zip(bboxes, labels) if label == 1] + spanning_cells = [{"bbox": bbox, "score": 1} for bbox, label in zip(bboxes, labels) if label in [4, 5]] + rows.sort(key=lambda x: x["bbox"][1] + x["bbox"][3]) + columns.sort(key=lambda x: x["bbox"][0] + x["bbox"][2]) + spanning_cells.sort(key=lambda x: -x["score"]) + cell_grid = [] + for row_num, row in enumerate(rows): + column_grid = [] + for column_num, column in enumerate(columns): + bbox = Rect(row["bbox"]).intersect(column["bbox"]) + column_grid.append(list(bbox)) + cell_grid.append(column_grid) + matches_by_spanning_cell = get_spanning_cell_rows_and_columns(spanning_cells, rows, columns) + for matches, spanning_cell in zip(matches_by_spanning_cell, spanning_cells): + for match in matches: + cell_grid[match[0]][match[1]] = spanning_cell["bbox"] + + return cell_grid + + +def grits_top(true_relative_span_grid, pred_relative_span_grid): + """ + Compute GriTS_Top given two matrices of cell 
relative spans. + + For the cell at grid location (i,j), let a(i,j) be its rowspan, + let β(i,j) be its colspan, let p(i,j) be the minimum row it occupies, + and let θ(i,j) be the minimum column it occupies. Its relative span is + bounding box [θ(i,j)-j, p(i,j)-i, θ(i,j)-j+β(i,j), p(i,j)-i+a(i,j)]. + + It gives the size and location of the cell each grid cell belongs to + relative to the current grid cell location, in grid coordinate units. + Note that for a non-spanning cell this will always be [0, 0, 1, 1]. + """ + return factored_2dmss(true_relative_span_grid, pred_relative_span_grid, reward_function=iou) + + +def grits_loc(true_bbox_grid, pred_bbox_grid): + """ + Compute GriTS_Loc given two matrices of cell bounding boxes. + """ + return factored_2dmss(true_bbox_grid, pred_bbox_grid, reward_function=iou) + + +def grits_con(true_text_grid, pred_text_grid): + """ + Compute GriTS_Con given two matrices of cell text strings. + """ + return factored_2dmss(true_text_grid, pred_text_grid, reward_function=lcs_similarity) + + +def remove_colgroup_tags(html_content: str) -> str: + soup = BeautifulSoup(html_content, "html.parser") + + # Find all 'colgroup' tags + colgroup_tags = soup.find_all("colgroup") + + # Remove each 'colgroup' tag + for tag in colgroup_tags: + tag.decompose() + + # Return the modified HTML as a string + return str(soup) + + +def make_html_table_homogeneous(html_content: str) -> str: + soup = BeautifulSoup(html_content, "html.parser") + + # Find all 'tr' tags (rows) + rows = soup.find_all("tr") + + # Determine the maximum number of cells in any row + max_cells = max(len(row.find_all(["td", "th"])) for row in rows) + + # Iterate over all rows + for row in rows: + cells = row.find_all(["td", "th"]) + num_cells = len(cells) + + # If a row has fewer cells than max_cells, add additional cells + if num_cells < max_cells: + for _ in range(max_cells - num_cells): + new_cell = soup.new_tag("td") # or 'th' if you want to add header cells + row.append(new_cell) 
+ + # Return the modified HTML as a string + return str(soup) + + +def html_to_cells(table_html: str): + """ + Parse an HTML representation of a table into a list of cells. + """ + try: + table_html = str(BeautifulSoup(table_html, "html.parser")) + tree = ET.fromstring(table_html) + except Exception as e: + logging.error(f"html_to_cells: {e}\n{table_html}") + return None + + table_cells = [] + + occupied_columns_by_row = defaultdict(set) + current_row = -1 + + # Get all td tags + stack = [] + stack.append((tree, False)) + while len(stack) > 0: + current, in_header = stack.pop() + + if current.tag == "tr": + current_row += 1 + + if current.tag == "td" or current.tag == "th": + if "colspan" in current.attrib: + colspan = int(current.attrib["colspan"]) + else: + colspan = 1 + if "rowspan" in current.attrib: + rowspan = int(current.attrib["rowspan"]) + else: + rowspan = 1 + row_nums = list(range(current_row, current_row + rowspan)) + try: + max_occupied_column = max(occupied_columns_by_row[current_row]) + current_column = min( + set(range(max_occupied_column + 2)).difference(occupied_columns_by_row[current_row]) + ) + except: + current_column = 0 + column_nums = list(range(current_column, current_column + colspan)) + for row_num in row_nums: + occupied_columns_by_row[row_num].update(column_nums) + + cell_dict = dict() + cell_dict["row_nums"] = row_nums + cell_dict["column_nums"] = column_nums + cell_dict["is_column_header"] = current.tag == "th" or in_header + cell_dict["cell_text"] = " ".join(current.itertext()) + table_cells.append(cell_dict) + + children = list(current) + for child in children[::-1]: + stack.append((child, in_header or current.tag == "th" or current.tag == "thead")) + + return table_cells + + +def grits_from_html(true_html, pred_html, metrics=["top", "con"]) -> Dict[str, float]: + """ + Compute GriTS_Con and GriTS_Top for two HTML sequences. 
+ """ + + outputs = {} + + # Convert HTML to list of cells + true_cells = html_to_cells(true_html) + pred_cells = html_to_cells(pred_html) + + # Convert lists of cells to matrices of grid cells + true_topology_grid = np.array(cells_to_relspan_grid(true_cells)) + pred_topology_grid = np.array(cells_to_relspan_grid(pred_cells)) + true_text_grid = np.array(cells_to_grid(true_cells, key="cell_text"), dtype=object) + pred_text_grid = np.array(cells_to_grid(pred_cells, key="cell_text"), dtype=object) + + # Compute GriTS_Top (topology) for ground truth and predicted matrices + if "top" in metrics: + ( + outputs["grits_top_f1"], + outputs["grits_top_precision"], + outputs["grits_top_recall"], + outputs["grits_top_upper_bound"], + ) = grits_top(true_topology_grid, pred_topology_grid) + + if "con" in metrics: + # Compute GriTS_Con (text content) for ground truth and predicted matrices + ( + outputs["grits_con_f1"], + outputs["grits_con_precision"], + outputs["grits_con_recall"], + outputs["grits_con_upper_bound"], + ) = grits_con(true_text_grid, pred_text_grid) + + if "alignment" in metrics: + true_substructure, pred_substructure = factored_2dmss( + true_text_grid, pred_text_grid, reward_function=lcs_similarity, return_substructures=True + ) + outputs["alignment"] = compute_lcs_df(true_substructure, pred_substructure) + + if "upper_bound" not in metrics: + outputs.pop("grits_top_upper_bound", None) + outputs.pop("grits_con_upper_bound", None) + + return outputs + + +def compute_lcs_df(array1: np.ndarray, array2: np.ndarray) -> pd.DataFrame: + if array1.shape != array2.shape: + raise ValueError("Input arrays must have the same shape") + + lcs_df = pd.DataFrame(index=range(array1.shape[0]), columns=range(array1.shape[1])) + + for i in range(array1.shape[0]): + for j in range(array1.shape[1]): + lcs_df.iat[i, j] = lcs_string(array1[i, j], array2[i, j]) + + lcs_df.columns = lcs_df.iloc[0] + lcs_df = lcs_df.iloc[1:] + + return lcs_df diff --git 
a/argilla-v1/src/extralit_v1/metrics/utils.py b/argilla-v1/src/extralit_v1/metrics/utils.py new file mode 100644 index 000000000..f6de9046e --- /dev/null +++ b/argilla-v1/src/extralit_v1/metrics/utils.py @@ -0,0 +1,107 @@ +import logging +from typing import Tuple, List, Dict, Union, Literal + +import numpy as np +import pandas as pd +from natsort import index_natsorted + +_LOGGER = logging.getLogger(__name__) + + +def harmonize_columns(true_df: pd.DataFrame, pred_df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Ensure the same column ordering by aligning the columns of the dataframes and converting all columns to the same dtypes. + Args: + true_df: + pred_df: + + Returns: + + """ + if not true_df.shape[1] or not pred_df.shape[1]: + return true_df, pred_df + + # Convert all columns to same dtypes + concat_iterables_fn = lambda x: ",".join(map(str, x)) if isinstance(x, (list, set, np.ndarray)) else x + true_df = true_df.map(concat_iterables_fn) + pred_df = pred_df.map(concat_iterables_fn) + pred_df = pred_df.astype(true_df.dtypes.drop(index=true_df.columns.difference(pred_df.columns)), errors="ignore") + + extra_columns = pred_df.columns.difference(true_df.columns) + pred_df = pred_df.reindex(columns=true_df.columns.intersection(pred_df.columns).append(extra_columns)) + + return true_df, pred_df + + +def is_numeric_dtype(series: pd.Series) -> bool: + return np.issubdtype(series.dtype, np.number) + + +def natural_sort_key_generator(s: pd.Series) -> np.ndarray: + fill_value = -999999 if is_numeric_dtype(s) else "zzz" + return np.argsort(index_natsorted(s.fillna(fill_value))) + + +def reorder_rows( + true_df: pd.DataFrame, pred_df: pd.DataFrame, index_columns: List[str] = None, verbose=False +) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Ensure the same row orders by determining the columns for sorting + + Args: + true_df: + pred_df: + index_columns: default None. 
If provided, these columns will be prioritized for sorting before other columns in the dataframes. + verbose: + """ + sort_columns = [] + if index_columns is not None and len(index_columns): + index_columns = true_df.columns.intersection(pred_df.columns).intersection(index_columns) + if index_columns.size: + sort_columns.extend(most_similar_columns(pred_df[index_columns], true_df[index_columns])) + common_columns = most_similar_columns(pred_df, true_df) + + if common_columns: + sort_columns.extend(common_columns) + + if sort_columns: + true_df = true_df.sort_values(by=sort_columns, key=natural_sort_key_generator) + pred_df = pred_df.sort_values(by=sort_columns, key=natural_sort_key_generator) + + if verbose: + _LOGGER.info(f"sort_columns: {sort_columns}") + + return true_df, pred_df + + +def convert_metrics_to_df( + metrics: Dict[str, float], format: Literal["series", "dataframe"] = "series", n_levels=2 +) -> Union[pd.Series, pd.DataFrame]: + if format in {"series", "dataframe"}: + metrics = pd.Series(metrics) + metrics.index = metrics.index.str.split("_", n=n_levels, expand=True) + if format == "dataframe": + metrics = metrics.to_frame() + + return metrics + + +def most_similar_columns(pred_df: pd.DataFrame, true_df: pd.DataFrame) -> List[str]: + intersecting_columns = { + col: len(set(true_df[col]) & set(pred_df[col])) + for col in true_df.columns.intersection(pred_df.columns) + if true_df[col].nunique() > 1 and pred_df[col].nunique() > 1 + } + + # If dtypes of intersecting columns are different, give warning + for col in intersecting_columns: + if true_df[col].dtype != pred_df[col].dtype: + _LOGGER.warning(f"\nColumn {col} has different dtypes: {true_df[col].dtype} and {pred_df[col].dtype}") + + if not intersecting_columns: + return [] + + intersecting_columns = sorted(intersecting_columns.items(), key=lambda x: x[1], reverse=True) + intersecting_columns = [col for col, _ in intersecting_columns] + + return intersecting_columns diff --git 
a/argilla/src/extralit/pipeline/__init__.py b/argilla-v1/src/extralit_v1/pipeline/__init__.py similarity index 100% rename from argilla/src/extralit/pipeline/__init__.py rename to argilla-v1/src/extralit_v1/pipeline/__init__.py diff --git a/argilla/src/extralit/pipeline/export/__init__.py b/argilla-v1/src/extralit_v1/pipeline/export/__init__.py similarity index 100% rename from argilla/src/extralit/pipeline/export/__init__.py rename to argilla-v1/src/extralit_v1/pipeline/export/__init__.py diff --git a/argilla-v1/src/extralit_v1/pipeline/export/dataset.py b/argilla-v1/src/extralit_v1/pipeline/export/dataset.py new file mode 100644 index 000000000..a14a6d72e --- /dev/null +++ b/argilla-v1/src/extralit_v1/pipeline/export/dataset.py @@ -0,0 +1,262 @@ +from typing import Dict, List, Optional, Literal, Union, Any + +import argilla_v1 as rg +import pandas as pd +import pandera as pa +from argilla_v1 import SpanLabelOption + + +def create_papers_dataset( + schema: pa.DataFrameSchema, + papers: pd.DataFrame, + fields: List[rg.TextField] = None, + span_columns: Optional[Dict[str, Union[Dict[str, str], List[SpanLabelOption]]]] = None, + metadata_columns: Optional[List[str]] = None, + vectors_settings: List[rg.VectorSettings] = None, + **kwargs, +) -> rg.FeedbackDataset: + fields = fields or [] + questions = [] + metadata_properties = {} + assert not metadata_columns or papers.columns.intersection(metadata_columns).size == len( + metadata_columns + ), "Some column in `metadata_columns` not found in the papers dataframe" + assert not span_columns or papers.columns.intersection(span_columns).size == len( + span_columns + ), "Some column in `span_columns` not found in the papers dataframe" + + for index_name in schema.index.names if schema.index else []: + metadata_properties[index_name] = rg.TermsMetadataProperty( + name=index_name, title=index_name.capitalize(), visible_for_annotators=True + ) + + # Questions + for field_name, column in schema.columns.items(): + is_multiselect 
= any(check.name == "multiselect" for check in column.checks) + if column.dtype.type == bool: + question = rg.LabelQuestion( + name=field_name, + title=column.title or field_name, + description=column.description, + labels={"True": "YES", "False": "NO"}, + required=not column.nullable, + ) + + elif column.dtype.type == list: + labels = next((check.statistics["isin"] for check in column.checks if "isin" in check.statistics), None) + question = rg.MultiLabelQuestion( + name=field_name, + title=column.title or field_name, + description=column.description, + labels=labels, + required=not column.nullable, + ) + + elif is_multiselect: + labels = next((check.statistics["isin"] for check in column.checks if check.name == "multiselect"), None) + question = rg.MultiLabelQuestion( + name=field_name, + title=column.title or field_name, + description=column.description, + labels=labels, + required=not column.nullable, + ) + else: + question = rg.TextQuestion( + name=field_name, + title=column.title or field_name, + description=column.description, + required=not column.nullable, + use_markdown=True, + ) + + questions.append(question) + + for column_name, labels in (span_columns or {}).items(): + question = rg.SpanQuestion( + name=f"span_{column_name}", field=column_name, title=column_name.capitalize(), labels=labels, required=False + ) + questions.append(question) + + # Metadatas + metadata_columns = papers.columns.intersection(metadata_columns or []) + for column_name, dtype in papers.dtypes.get(metadata_columns, {}).items(): + if column_name in schema.columns or column_name == "file_path": + continue + try: + if dtype == bool: + metadata_prop = rg.TermsMetadataProperty( + name=column_name, title=column_name.capitalize(), visible_for_annotators=True + ) + elif dtype == float: + metadata_prop = rg.FloatMetadataProperty( + name=column_name, title=column_name.capitalize(), visible_for_annotators=True + ) + elif dtype == int: + metadata_prop = rg.IntegerMetadataProperty( + 
name=column_name, title=column_name.capitalize(), visible_for_annotators=True + ) + elif dtype == object: + metadata_prop = rg.TermsMetadataProperty( + name=column_name, title=column_name.capitalize(), visible_for_annotators=True + ) + else: + metadata_prop = rg.TermsMetadataProperty( + name=column_name, title=column_name.capitalize(), visible_for_annotators=True + ) + + metadata_properties[column_name] = metadata_prop + except Exception as e: + print(f"Failed to define metadata property {column_name} for the dataset: {e}") + + if not any(field.name == "metadata" for field in fields): + fields.insert(0, rg.TextField(name="metadata", title="Metadata", use_markdown=True)) + + return rg.FeedbackDataset( + fields=fields, + questions=questions, + metadata_properties=list(metadata_properties.values()) if metadata_properties else None, + guidelines=schema.description, + vectors_settings=vectors_settings, + **kwargs, + ) + + +def create_extraction_dataset( + fields: Optional[List[rg.TextField]] = None, + questions: Optional[List[rg.TextQuestion]] = None, + metadata_properties: Optional[List[rg.TermsMetadataProperty]] = None, + vectors_settings: Optional[List[rg.VectorSettings]] = None, +) -> rg.FeedbackDataset: + extraction_dataset = rg.FeedbackDataset( + guidelines="Manually validate every data entries in the data extraction sheet to build a " + "gold-standard validation dataset.", + fields=[ + rg.TextField(name="metadata", title="Reference:", required=True, use_markdown=True), + rg.TextField(name="extraction", title="Extracted data:", required=True, use_table=True), + rg.TextField(name="context", title="Top relevant segments:", required=False, use_markdown=True), + *(fields or []), + ], + questions=[ + rg.MultiLabelQuestion( + name="context-relevant", + title="Which of the document section(s) attributed to this data extraction table?", + description="Please identify which section in the source PDF the data extract came from, and select the matching section header(s) in 
this multi-selection list.", + type="dynamic_multi_label_selection", + labels=["Not listed"], + visible_labels=3, + required=False, + ), + rg.MultiLabelQuestion( + name="extraction-source", + title="Where did the extracted data primarily came from?", + labels=["Text", "Table", "Figure"], + required=False, + ), + rg.TextQuestion( + name="extraction-correction", + title="Provide a correction to the extracted data:", + required=True, + use_table=True, + ), + rg.TextQuestion( + name="notes", + title="Mention any notes for other extractors (or prompt for the LLM)", + required=False, + use_markdown=True, + ), + rg.TextQuestion( + name="issue", + title="Flag an issue for discrepancy between the Suggestion's extraction and your own extraction for Consensus Review", + description="If you are an extractor, please do not choose Approve, but you may choose Needs Review to flag an issue to discuss in a Consensus review. " + "If you are a reviewer, choose Approve to validate the extraction, or Needs redo extraction to let extractors know this record needs further work.", + required=False, + use_markdown=True, + ), + *(questions or []), + ], + vectors_settings=vectors_settings, + metadata_properties=[ + rg.TermsMetadataProperty(name="reference", title="Reference", visible_for_annotators=True), + rg.TermsMetadataProperty(name="type", title="Question Type", visible_for_annotators=True), + *(metadata_properties or []), + ], + ) + + return extraction_dataset + + +def create_preprocessing_dataset(): + dataset = rg.FeedbackDataset( + fields=[ + rg.TextField(name="metadata", title="Metadata:", required=True, use_markdown=True), + rg.TextField(name="header", title="Title:", required=True, use_markdown=True), + rg.TextField(name="image", title="Image:", required=False, use_markdown=True), + rg.TextField(name="text-1", title="Method 1:", required=False, use_markdown=True, use_table=False), + rg.TextField(name="text-2", title="Method 2:", required=False, use_markdown=True, use_table=False), 
+ rg.TextField(name="text-3", title="Method 3:", required=False, use_markdown=True, use_table=False), + rg.TextField(name="text-4", title="Method 4:", required=False, use_markdown=True, use_table=False), + rg.TextField(name="text-5", title="Method 5:", required=False, use_markdown=True, use_table=False), + ], + questions=[ + rg.LabelQuestion( + name="ranking", + title="Which method extracted the most complete and accurate information?", + labels={ + "text-1": "Method 1", + "text-2": "Method 2", + "text-3": "Method 3", + "text-4": "Method 4", + "text-5": "Method 5", + "none": "None", + }, + type="dynamic_label_selection", + required=True, + ), + rg.MultiLabelQuestion( + name="mismatched", + title="Which of the method(s) extracted the wrong table/figure? (if any)", + description="This indication helps in evaluating these models accuracy", + type="dynamic_multi_label_selection", + labels={ + "text-1": "Method 1", + "text-2": "Method 2", + "text-3": "Method 3", + "text-4": "Method 4", + "text-5": "Method 5", + }, + required=False, + ), + rg.TextQuestion( + name="header-correction", + title="Correct the table or figure title:", + required=False, + use_markdown=False, + ), + rg.TextQuestion( + name="text-correction", + title="Correct the extracted data:", + required=False, + use_markdown=True, + use_table=True, + ), + rg.TextQuestion( + name="notes", + title="Mention any notes here", + required=False, + use_markdown=False, + ), + ], + metadata_properties=[ + rg.TermsMetadataProperty(name="reference", title="Reference"), + rg.TermsMetadataProperty(name="page_number", title="Page Number"), + rg.TermsMetadataProperty(name="number", title="Table/Figure Number"), + rg.TermsMetadataProperty(name="type", title="Element type"), + rg.TermsMetadataProperty(name="pmid", title="Document Pubmed ID"), + rg.TermsMetadataProperty(name="doc_id", title="Document ID", visible_for_annotators=False), + rg.FloatMetadataProperty(name="probability", title="Detection probability"), + 
rg.TermsMetadataProperty(name="annotators"), + ], + ) + + return dataset diff --git a/argilla-v1/src/extralit_v1/pipeline/export/record.py b/argilla-v1/src/extralit_v1/pipeline/export/record.py new file mode 100644 index 000000000..a225a9bac --- /dev/null +++ b/argilla-v1/src/extralit_v1/pipeline/export/record.py @@ -0,0 +1,248 @@ +import logging +import uuid +from typing import Dict, List, Optional + +import argilla_v1 as rg +import pandas as pd +import pandera as pa +from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset +from llama_index.embeddings.openai import OpenAIEmbedding +from tqdm import tqdm + +from extralit_v1.convert.json_table import df_to_json +from extralit_v1.extraction.models.paper import PaperExtraction +from extralit_v1.extraction.models.response import ResponseResults + +_LOGGER = logging.getLogger(__name__) + + +def create_extraction_records( + paper_extractions: Dict[str, PaperExtraction], + papers: pd.DataFrame, + responses: Optional[Dict[str, ResponseResults]] = None, + dataset: RemoteFeedbackDataset = None, + metadata: Optional[Dict[str, str]] = None, +) -> List[rg.FeedbackRecord]: + """ + Push the extractions to the Argilla (Preprocessing) FeedbackDataset. + + Args: + paper_extractions: Dict[str, PaperExtraction], required + The extractions for each paper. + papers: pd.DataFrame, required + The papers dataframe. + responses: Dict[str, ResponseResults], optional + The responses for each paper. + dataset: RemoteFeedbackDataset, default=None + The Argilla dataset. + metadata: Dict[str,str], default=None + Additional metadata to add to the records. 
+ + Returns: + List[rg.FeedbackRecord] + """ + assert ( + isinstance(dataset, RemoteFeedbackDataset) or dataset is None + ), f"dataset must be an instance of RemoteFeedbackDataset, given {type(dataset)}" + records = [] + for ref, extractions in paper_extractions.items(): + paper = papers.loc[[ref]].iloc[0] + + if dataset is not None: + if isinstance(paper.file_path, str): + doc = dataset.add_document( + rg.Document.from_file( + paper.file_path, reference=ref, pmid=paper.get("pmid"), doi=paper.get("doi"), id=paper.get("id") + ) + ) + else: + raise Exception(f"Unable to load document for {ref}") + else: + doc = rg.Document(file_name="/") + + ### metadata ### + metadata = metadata or {} + metadata["reference"] = ref + if doc.id: + metadata["doc_id"] = str(doc.id) + if isinstance(doc.pmid, str): + metadata["pmid"] = doc.pmid + if doc.doi: + metadata["doi"] = doc.doi + + schema_order = extractions.schemas.ordering + for schema_name, extraction in extractions.items(): + schema = extractions.schemas[schema_name] + if extraction is None or extraction.empty: + _LOGGER.warning(f"No {schema_name} extraction for {ref}, generating an empty table.") + extraction = generate_empty_extraction(schema, size=2) + + ### fields ### + fields = { + "extraction": df_to_json( + extraction, schema, drop_columns=["publication_ref", "Group"], metadata={"reference": ref} + ), + } + + ref_url = f"dataset/{dataset.id}/annotation-mode" + nav_df = pd.DataFrame( + [ + [ + f"[Step {i}]({ref_url}?_page={i}&_metadata=reference.{ref})" + if step != schema_name + else f"Step {i} (here)" + for i, step in enumerate(schema_order, start=1) + ] + ], + columns=schema_order, + index=pd.Index(["Navigate to"]), + ) + fields["metadata"] = f"Paper: {ref}\n" + nav_df.to_markdown(index=True) + + # Retrieve most relevant context + if responses and ref in responses and schema_name in responses[ref].items: + nodes_df = responses[ref].items[schema_name].get_nodes_info() + nodes_df.drop( + 
columns=nodes_df.columns.difference(["relevance", "header", "page_number", "text"]), + errors="ignore", + inplace=True, + ) + # nodes_df['page_number'] = nodes_df['page_number'].map(lambda x: f"[Page {x}](#page_number.{x})" if x else None) + fields["context"] = nodes_df.style.background_gradient( + axis=1, subset=["relevance"], cmap="RdYlGn" + ).to_html(index=False, na_rep="") + + # ### suggestions ### + # suggestions = [ + # { + # "question_name": "context-relevant", + # "value": headers, + # "type": "selection", + # }, + # ] + + record = rg.FeedbackRecord( + fields=fields, + # suggestions=suggestions if len(headers) else [], + metadata={ + **metadata, + "type": schema_name, + }, + ) + records.append(record) + + return records + + +def create_publication_records( + papers: pd.DataFrame, + schema: pa.DataFrameSchema, + dataset: RemoteFeedbackDataset, + embed_model="text-embedding-3-large", +) -> List[rg.FeedbackRecord]: + """ + Push the publications to the Argilla (Preprocessing) FeedbackDataset. 
+ """ + assert ( + papers.index.name == "reference" + ), f"The given dataframe must have index name as 'reference', given {papers.index.name}" + records = [] + question_names = [q.name for q in dataset.questions] + + embed_models = {} + if embed_model: + for vectors_setting in dataset.vectors_settings: + embed_models[vectors_setting.name] = OpenAIEmbedding( + model=embed_model, dimensions=vectors_setting.dimensions or 1024 + ) + + for reference, paper in tqdm(papers.iterrows()): + if dataset is not None: + assert isinstance(dataset, RemoteFeedbackDataset) + if isinstance(paper.file_path, str): + doc = dataset.add_document( + rg.Document.from_file( + paper.file_path, + reference=str(reference), + pmid=paper.get("pmid"), + doi=paper.get("doi"), + id=paper.get("id", uuid.uuid4()), + ) + ) + else: + raise Exception(f"Unable to load document for {paper.name} from {paper.file_path}") + else: + doc = rg.Document(file_name="/") + + metadata = { + "reference": paper.name, + **({"doc_id": str(doc.id)} if doc.id is not None else {}), + } + + dataset_field_names = [f.name for f in dataset.fields] + publication_metadata = {k: v for k, v in paper.to_dict().items() if k in dataset_field_names} + metadata.update(publication_metadata) + publication_metadata = pd.Series(publication_metadata, name=paper.name) + + fields = { + **{k: v for k, v in paper.to_dict().items() if k in dataset_field_names and not pd.isna(v)}, + } + if "metadata" in dataset_field_names: + fields["metadata"] = publication_metadata.to_frame().to_html(index=True) + + vectors = { + name: model.get_text_embedding(fields[name]) + for name, model in embed_models.items() + if name in fields and fields[name].strip() + } + + # Create suggestions + agent = None + suggestions = [] + for field in schema.columns: + if field in question_names and field in paper and paper[field] is not None: + suggestions.append( + { + "question_name": field.lower(), + "value": str(paper[field]), + "type": "human", + "agent": agent, + } + ) + 
+ record = rg.FeedbackRecord( + fields=fields, + metadata={k: v for k, v in metadata.items() if not pd.isna(v)}, + suggestions=suggestions, + vectors=vectors, + ) + records.append(record) + + return records + + +def generate_empty_extraction(schema: pa.DataFrameSchema, size=2) -> pd.DataFrame: + default_value = ["NA"] * size + df = pd.DataFrame.from_dict({col: default_value for i, col in enumerate(schema.columns) if i < 5}) + + if isinstance(schema.index, pa.MultiIndex): + index_names = [] + index_prefixes = [] + for index in schema.index.indexes: + index_names.append(index.name) + str_startswith_check = next(check for check in index.checks if check.name == "str_startswith") + index_prefixes.append(str_startswith_check.statistics["string"]) + index = pd.MultiIndex.from_tuples( + [tuple(f"{prefix}{i+1}" for prefix in index_prefixes) for i in range(size)], names=index_names + ) + elif schema.index: + str_startswith_check = next(check for check in schema.index.checks if check.name == "str_startswith") + prefix = str_startswith_check.statistics["string"] + index = pd.Index([f"{prefix}{i+1}" for i in range(size)], name=schema.index.name if schema.index else None) + else: + index = None + + if index is not None: + df = df.set_index(index) + + return df diff --git a/argilla/src/extralit/pipeline/ingest/__init__.py b/argilla-v1/src/extralit_v1/pipeline/ingest/__init__.py similarity index 100% rename from argilla/src/extralit/pipeline/ingest/__init__.py rename to argilla-v1/src/extralit_v1/pipeline/ingest/__init__.py diff --git a/argilla-v1/src/extralit_v1/pipeline/ingest/paper.py b/argilla-v1/src/extralit_v1/pipeline/ingest/paper.py new file mode 100644 index 000000000..d08b4caf4 --- /dev/null +++ b/argilla-v1/src/extralit_v1/pipeline/ingest/paper.py @@ -0,0 +1,135 @@ +from typing import List, Optional +from collections import defaultdict + +import argilla_v1 as rg +import pandas as pd +from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord + +from 
extralit_v1.convert.json_table import json_to_df, is_json_table +from extralit_v1.extraction.models.paper import PaperExtraction +from extralit_v1.extraction.models.schema import SchemaStructure +from extralit_v1.pipeline.ingest.record import get_record_data + + +def get_paper_extraction_status( + references: List[str], + schemas: SchemaStructure, + paper_dataset: rg.FeedbackDataset, + extraction_dataset: rg.FeedbackDataset = None, + preprocessing_dataset: rg.FeedbackDataset = None, +) -> pd.DataFrame: + assert schemas.singleton_schema is not None, "Document schema must be given in the schemas." + users = rg.Workspace.from_name(paper_dataset.workspace.name).users + users_id_to_username = {u.id: u.username for u in users} + + paper_records: List[RemoteFeedbackRecord] = paper_dataset.filter_by( + metadata_filters=rg.TermsMetadataFilter(name="reference", values=references) + ).records + + document_schema = schemas.singleton_schema + references_data = [] + for record in paper_records: + reference = record.metadata["reference"] + metadata = record.metadata + values = get_record_data( + record, + answers=document_schema.columns, + suggestions=document_schema.columns, + status=["submitted"], + include_user_id=True, + ) + values["reference"] = reference + values["checked_out"] = users_id_to_username.get(values.pop("user_id", "NA"), "NA") + user_statuses = { + users_id_to_username.get(response.user_id, "NA"): response.status.name for response in record.responses + } + values[document_schema.name] = user_statuses + metadata.update(values) + references_data.append(metadata) + + references_df = pd.DataFrame(references_data).set_index("reference") + + extraction_records: List[RemoteFeedbackRecord] = extraction_dataset.filter_by( + metadata_filters=rg.TermsMetadataFilter(name="reference", values=references) + ).records + + extraction_schemas = schemas.schemas + extraction_data = defaultdict(dict) + for record in extraction_records: + schema_name = record.metadata["type"] + 
reference = record.metadata["reference"] + user_statuses = { + users_id_to_username.get(response.user_id, "NA"): response.status.name for response in record.responses + } + extraction_data[reference][schema_name] = user_statuses + extraction_df = pd.DataFrame.from_dict( + extraction_data, + orient="index", + ) + extraction_df.index.name = "reference" + + extraction_status = references_df.join(extraction_df, on="reference") + return extraction_status + + +def get_paper_extractions( + paper: pd.Series, + dataset: rg.FeedbackDataset, + schemas: SchemaStructure, + answer: str, + field: Optional[str] = None, + suggestion: Optional[str] = None, + users: Optional[List[rg.User]] = None, + statuses=["submitted"], +) -> PaperExtraction: + reference = paper.name + records: List[RemoteFeedbackRecord] = dataset.filter_by( + metadata_filters=rg.TermsMetadataFilter(name="reference", values=[reference]) + ).records + + extractions = {} + durations = {} + updated_at = {} + inserted_at = {} + user_id = {} + + for record in records: + if record.metadata["reference"] != reference: + continue + + outputs = get_record_data( + record, + fields=field, + answers=[answer, "duration"] if answer else ["duration"], + suggestions=[suggestion] if suggestion else [], + users=users, + include_user_id=True, + status=statuses, + ) + + if suggestion in outputs: + table_json = outputs[suggestion] + elif answer in outputs and is_json_table(outputs[answer]): + table_json = outputs[answer] + elif field in outputs and is_json_table(outputs[field]): + table_json = outputs[field] + else: + table_json = None + + for schema in schemas.schemas: + if schema.name == record.metadata["type"]: + extractions[schema.name] = json_to_df(table_json, schema=schema) + durations[schema.name] = outputs.get("duration", None) + updated_at[schema.name] = max([res.updated_at for res in record.responses], default=record.updated_at) + inserted_at[schema.name] = record.inserted_at + user_id[schema.name] = outputs.get("user_id", 
None) + + return PaperExtraction( + reference=reference, + extractions=extractions, + schemas=schemas, + durations=durations, + updated_at=updated_at, + inserted_at=inserted_at, + user_id=user_id, + ) diff --git a/argilla-v1/src/extralit_v1/pipeline/ingest/record.py b/argilla-v1/src/extralit_v1/pipeline/ingest/record.py new file mode 100644 index 000000000..ef38a60f0 --- /dev/null +++ b/argilla-v1/src/extralit_v1/pipeline/ingest/record.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import Optional, Union, List, Dict, Any, Literal + +import argilla_v1 as rg +from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord +from argilla.client.sdk.users.models import UserModel + +from extralit_v1.convert.json_table import is_json_table +from extralit_v1.pipeline.update.suggestion import get_record_suggestion_value + + +def get_record_data( + record: Union[RemoteFeedbackRecord, rg.FeedbackRecord], + fields: Optional[Union[List[str], str]] = None, + answers: Optional[Union[List[str], str]] = None, + suggestions: Optional[Union[List[str], str]] = None, + metadatas: Optional[Union[List[str], str]] = None, + users: Optional[Union[List[rg.User], rg.User]] = None, + include_user_id: bool = False, + status: Optional[List[Literal["submitted", "draft", "pending", "discarded"]]] = ["submitted", "draft"], +) -> Dict[str, Any]: + """ + Extracts data from a feedback record based on the specified parameters. + + Args: + record (Union[RemoteFeedbackRecord, rg.FeedbackRecord]): The feedback record to extract data from. + fields (Union[List[str], str]): The fields to extract from the record. + answers (Optional[Union[List[str], str]]): The answers to extract from the record's responses. + suggestions (Optional[Union[List[str], str]]): The suggestions to extract from the record's responses. + metadatas (Optional[Union[List[str], str]]): The metadata keys to extract from the record. 
+ users (Optional[Union[List[rg.User], rg.User]]): The users whose responses should be considered. + include_user_id (bool, optional): Whether to include the user ID in the output. Defaults to False. + include_consensus (Optional[str], optional): If not None, then include the concensus status with the provided key as the argument. Defaults to None. + status (Optional[List[str]], optional): The statuses to filter. Defaults to ["submited"]. + + Returns: + Dict[str, Any]: A dictionary containing the extracted data. + + """ + fields = [fields] if isinstance(fields, str) else set(fields) if fields else [] + answers = [answers] if isinstance(answers, str) else set(answers) if answers else [] + suggestions = [suggestions] if isinstance(suggestions, str) else set(suggestions) if suggestions else [] + metadatas = [metadatas] if isinstance(metadatas, str) else set(metadatas) if metadatas else [] + users = [users] if isinstance(users, (UserModel, rg.User)) else list(users) if users else [] + responses = record.responses + + if users: + user_ids = {u.id for u in users} + responses = [r for r in responses if r.user_id in user_ids] + + if status: + responses = [r for r in responses if r.status.value in status] + data = {} + for field in fields: + if field in record.fields: + data[field] = record.fields[field] + + for suggestion in suggestions: + data[suggestion] = get_record_suggestion_value(record, question_name=suggestion, users=users) + + selected_response = next((r for r in responses[::-1] if r.values), None) + for answer in answers: + if selected_response and answer in selected_response.values: + data[answer] = selected_response.values[answer].value + + if include_user_id: + data["user_id"] = selected_response.user_id + + for key in metadatas: + if key in record.metadata and key not in data: + data[key] = record.metadata[key] + + return data + + +def get_record_timestamp(record: Union[RemoteFeedbackRecord, rg.FeedbackRecord]) -> Optional[datetime]: + timestamp = 
record.updated_at or record.inserted_at + + if len(record.responses): + response = record.responses[-1] + response_timestamp = response.updated_at or response.inserted_at + if response_timestamp and response_timestamp > timestamp: + timestamp = response_timestamp + + return timestamp diff --git a/argilla-v1/src/extralit_v1/pipeline/ingest/segment.py b/argilla-v1/src/extralit_v1/pipeline/ingest/segment.py new file mode 100644 index 000000000..040022cfa --- /dev/null +++ b/argilla-v1/src/extralit_v1/pipeline/ingest/segment.py @@ -0,0 +1,92 @@ +from typing import List, Literal + +import argilla_v1 as rg +import pandas as pd +from argilla_v1 import FeedbackRecord + +from extralit_v1.pipeline.ingest.record import get_record_data +from extralit_v1.preprocessing.segment import Segments, FigureSegment, TableSegment + + +def get_paper_tables( + paper: pd.Series, + dataset: rg.FeedbackDataset, + select: str = "text-correction", + response_status: List[Literal["discarded", "submitted", "pending", "draft"]] = ["submitted"], +) -> Segments: + """ + Get the tables manually annotated a given paper in an Argilla (Preprocessing) FeedbackDataset. + + Args: + paper: pd.Series, required + A paper from the dataset. + dataset: rg.FeedbackDataset, required + The Argilla (Preprocessing) FeedbackDataset. + select: str, default='text-correction' + The field to select from the dataset records. + response_status: List[str], default=['discarded'] + + Returns: + Segments: The tables manually annotated for the given paper. 
+ """ + records: List[FeedbackRecord] = dataset.filter_by( + metadata_filters=rg.TermsMetadataFilter(name="reference", values=[paper.name]), response_status=response_status + ).records + + segments = Segments() + for record in records: + values = get_record_data( + record, + fields=["text-1", "text-2", "text-3", "text-4", "text-5", "header"], + answers=["text-correction", "header-correction", "footer-correction", "ranking", "duration"], + metadatas=["page_number", "number"], + status=response_status, + ) + if "ranking" not in values or values["ranking"] == "none": + continue + + try: + if select.strip().lower() == "ranking" and values["ranking"] in values: + html = values[values["ranking"]] + + elif select in values and "correction" in select and not values[select]: + # Skip the empty corrections + html = values[values["ranking"]] + + elif select in values: + html = values[select] + else: + continue + except: + continue + + if "header-correction" in values: + header = values["header-correction"] + else: + header = values["header"] + + type = values.get("number", "table").split(" ")[0].lower() + if type == "figure": + segment = FigureSegment( + id=str(record.id), + header=header.strip(), + footer=values.get("footer-correction", None), + page_number=values.get("page_number", None), + text=html, + html=html, + duration=values.get("duration", None), + ) + else: + segment = TableSegment( + id=str(record.id), + header=header.strip(), + footer=values.get("footer-correction", None), + page_number=values.get("page_number", None), + text=html, + html=html, + duration=values.get("duration", None), + ) + + segments.items.append(segment) + + return segments diff --git a/argilla-v1/src/extralit_v1/pipeline/ingest/trace.py b/argilla-v1/src/extralit_v1/pipeline/ingest/trace.py new file mode 100644 index 000000000..a27d890e5 --- /dev/null +++ b/argilla-v1/src/extralit_v1/pipeline/ingest/trace.py @@ -0,0 +1,46 @@ +import re +from typing import Dict, List, Iterator, Any + +from 
langfuse import Langfuse +from langfuse.api import ObservationsView + +from extralit_v1.extraction.models import PaperExtraction +from extralit_v1.metrics.extraction import grits_multi_tables +from extralit_v1.server.context.llamaindex import get_langfuse_callback + +langfuse_callback = get_langfuse_callback() +langfuse_client = langfuse_callback.langfuse + + +def get_langfuse_traces( + langfuse_client: Langfuse, + user_id: str, + references: List[str] = None, + schema_names: List[str] = None, + input_match="Please complete the following", + page_size=50, +) -> Iterator[ObservationsView]: + page = 1 + max_page = None + while max_page is None or page <= max_page: + traces_batch = langfuse_client.get_observations(user_id=user_id, limit=page_size, page=page) + max_page = traces_batch.meta.total_pages + + for trace in traces_batch.data: + if trace is None or not trace.metadata: + continue + trace_metadata: dict = trace.metadata.get("metadata", {}) or {} + if references and not next( + (metadata.get("reference") in references for metadata in trace_metadata.values()), None + ): + continue + if input_match and not re.search(input_match, trace.input): + continue + # if schema_names and not next((metadata.get('schema') in schema_names for metadata in trace.metadata.values()), None): + # continue + if not isinstance(trace.output, dict) or not trace.output.get("items", []): + continue + + yield trace + + page += 1 diff --git a/argilla/src/extralit/pipeline/update/__init__.py b/argilla-v1/src/extralit_v1/pipeline/update/__init__.py similarity index 100% rename from argilla/src/extralit/pipeline/update/__init__.py rename to argilla-v1/src/extralit_v1/pipeline/update/__init__.py diff --git a/argilla-v1/src/extralit_v1/pipeline/update/monitor_records.py b/argilla-v1/src/extralit_v1/pipeline/update/monitor_records.py new file mode 100644 index 000000000..39276b5a5 --- /dev/null +++ b/argilla-v1/src/extralit_v1/pipeline/update/monitor_records.py @@ -0,0 +1,152 @@ +import logging 
+from collections import defaultdict +from time import sleep +from typing import Dict, Optional, Any + +import argilla_v1 as rg +from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset +from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord + +from extralit_v1.convert.json_table import json_to_df, df_to_json +from extralit_v1.pipeline.ingest.record import get_record_table, get_record_timestamp + + +class MonitorIntegrationDataset: + _logger = logging.getLogger("MonitorIntegrationDataset") + _logger.setLevel(logging.INFO) + + def __init__(self, from_dataset: RemoteFeedbackDataset, to_dataset: RemoteFeedbackDataset) -> None: + """ + + + Args: + from_dataset (rg.FeedbackDataset): The dataset to monitor. + to_dataset (rg.FeedbackDataset): The dataset to update. + """ + self.from_dataset = from_dataset + self.to_dataset = to_dataset + self._records = None + self.fetch_records() + + def fetch_records(self) -> None: + # Caches the records in a dict with the reference as key for faster access + self._records = {record.metadata["reference"]: record for record in self.to_dataset.records} + + @property + def records(self) -> Dict[str, RemoteFeedbackRecord]: + return self._records + + def create_record(self, reference: str, metadata: Dict[str, Any]) -> rg.FeedbackRecord: + if "type" in metadata: + del metadata["type"] # Remove the type from the metadata + record = rg.FeedbackRecord( + fields={ + "metadata": reference, + }, + metadata=metadata, + ) + return record + + def monitor(self, batch_size: int = 5, infinite: bool = True) -> None: + """ + Monitor the dataset by processing it in batches. + + Args: + batch_size (int): Size of each batch. 
+ """ + extractions_batch = defaultdict(lambda: {}) + last_batch_size = 0 + + while True: + for record in self.from_dataset: + reference = record.metadata["reference"] + + extraction = get_record_table(record, field="extraction", answer="extraction-correction") + updated_at = get_record_timestamp(record) + + # Update timestamp if newer + if ( + "updated_at" not in extractions_batch[reference] + or not extractions_batch[reference]["updated_at"] + or (updated_at and updated_at > extractions_batch[reference]["updated_at"]) + ): + extractions_batch[reference]["updated_at"] = updated_at + + # Add extraction to batch + if extraction: + question_type = record.metadata["type"] + extractions_batch[reference][question_type] = extraction + extractions_batch[reference]["metadata"] = record.metadata + + # Check if the batch should be processed + if len(extractions_batch) >= batch_size or ( + len(extractions_batch) > 0 and len(extractions_batch) == last_batch_size + ): + extractions_batch = self.process_batch(extractions_batch) + last_batch_size = 0 # Reset the last batch size after processing + else: + last_batch_size = len(extractions_batch) # Update last batch size + + if not infinite: + break + sleep(2) + + if extractions_batch: + self.process_batch(extractions_batch) + + def process_batch(self, extractions_batch: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Optional[str]]]: + """ + Add/Update the records to the `to_dataset` with the extractions in the batch. + + Args: + extractions_batch (Dict[str, Dict[str, Any]]): A dict with the reference as key and a dict with the + `from_dataset` records as value. 
+ """ + + update_records = {} + add_records = {} + for reference, extractions in extractions_batch.items(): + if not extractions["updated_at"]: + continue + + is_new_record = reference not in self.records + + if is_new_record: + record = rg.FeedbackRecord( + fields={ + "metadata": f"[{reference}](dataset/{self.from_dataset.id}/annotation-mode?_page=1&_status=valid&_metadata=reference.{reference})", + }, + metadata={k: v for k, v in extractions["metadata"].items() if k != "type"}, + ) + else: + # Check if record need to be updated + record = self.records[reference] + if extractions["updated_at"] and extractions["updated_at"] <= get_record_timestamp(record): + continue + + if is_new_record: + add_records[reference] = record + else: + update_records[reference] = record + + sep = "\n\t" + if update_records: + self._logger.info( + f"Updating {len(update_records)} records: \n" + f"{sep.join([rec.metadata['reference'] + ' ' + str(list(rec.fields.keys())) for rec in update_records.values()])} \n" + ) + for record in update_records.values(): + record.updated_at = extractions_batch[record.metadata["reference"]]["updated_at"] + self.to_dataset.update_records(list(update_records.values()), show_progress=False) + + if add_records: + self._logger.info( + f"Adding {len(add_records)} records: \n" + f"{sep.join([rec.metadata['reference'] + ' ' + str(list(rec.fields.keys())) for rec in add_records.values()])} \n" + ) + self.to_dataset.add_records(list(add_records.values()), show_progress=False) + self.fetch_records() + + # reset batch to only last reference + new_batch = defaultdict(lambda: {}) + return new_batch diff --git a/argilla-v1/src/extralit_v1/pipeline/update/monitor_suggestion.py b/argilla-v1/src/extralit_v1/pipeline/update/monitor_suggestion.py new file mode 100644 index 000000000..2134fff1a --- /dev/null +++ b/argilla-v1/src/extralit_v1/pipeline/update/monitor_suggestion.py @@ -0,0 +1,146 @@ +import json +import logging +import re +from typing import List + +import 
argilla_v1 as rg +import numpy as np +import pandas as pd +from bs4 import BeautifulSoup +from setfit import SetFitModel + + +class Monitor: + _logger = logging.getLogger("Monitor") + _logger.setLevel(logging.INFO) + + def __init__(self, dataset: rg.FeedbackDataset, question: str) -> None: + """ + Initialize the Monitor. + + Args: + dataset (rg.FeedbackDataset): The dataset to monitor. + itnrecal (SetFitModel): The NLP itnrecal for predictions. + question (str): The specific question for monitoring. + """ + self.dataset = dataset + self.question = question + + @staticmethod + def _strip_html(text_with_html: str) -> str: + """ + Remove HTML tags from the given text. + + Args: + text_with_html (str): Text containing HTML tags. + + Returns: + str: Text without HTML tags. + """ + soup = BeautifulSoup(text_with_html, "html.parser") + plain_text = soup.get_text() + stripped_text = re.sub(r"\s+", " ", plain_text) # remove duplicate whitespaces + return stripped_text.strip() + + def monitor(self, batch_size: int = 50, infinite: bool = True) -> None: + """ + Monitor the dataset by processing it in batches. + + Args: + batch_size (int): Size of each batch. + """ + record_batch = [] + while True: + for rec in self.dataset: + questions = [sug.question_name for sug in rec.suggestions] + if self.question not in questions: + record_batch.append(rec) + else: + record_batch.append(rec) + if len(record_batch) == batch_size: + self.update_batch(record_batch) + record_batch = [] + if record_batch: + self.update_batch(record_batch) + record_batch = [] + if not infinite: + break + + def update_batch(self, batch: List[rg.FeedbackRecord]) -> None: + """ + Update the itnrecal predictions for a batch of records. + + Args: + batch (List[rg.FeedbackRecord]): Batch of records. 
+ """ + self._logger.info(f"Updating batch of {len(batch)} records") + + texts, keys = self.get_texts_and_keys(batch) + print(texts) + self.add_suggestions_to_records(batch) + + def get_texts_and_keys(self, batch: List[rg.FeedbackRecord]) -> tuple: + """ + Extract texts and keys from a batch of records. + + Args: + batch (List[rg.FeedbackRecord]): Batch of records. + + Returns: + tuple: Texts and keys. + """ + self._logger.info("Formatting texts and keys") + texts = [] + keys = [] + for rec in batch: + print(rec.fields) + key = list(dict(json.loads(rec.fields["header"])).keys())[0] + keys.append(key) + texts.append(rec.fields["text-1"]) + return texts, keys + + def merge_predictions(self, preds: List[List], ids: List[int]) -> List[np.ndarray]: + """ + Merge predictions and aggregate them. + + Args: + preds (List[List]): List of prediction lists. + ids (List[int]): List of corresponding IDs. + + Returns: + List[np.ndarray]: List of merged prediction arrays. + """ + self._logger.info("Merging predictions") + df = pd.DataFrame(columns=["ids", "preds"]) + df["ids"] = ids + df["preds"] = [pred.tolist() for pred in preds] + df = df.groupby("ids").agg({"preds": list}).reset_index() + preds = df["preds"].tolist() + mean_preds = [np.mean(pred, axis=0) for pred in preds] + return mean_preds + + def add_suggestions_to_records( + self, + batch: List[rg.FeedbackRecord], + ) -> None: + """ + Add itnrecal predictions to the records in the batch. + + Args: + batch (List[rg.FeedbackRecord]): Batch of records. + preds (List[np.ndarray]): List of itnrecal predictions. 
+ """ + self._logger.info("Adding suggestions to records") + updated_records = [] + for rec in zip(batch): + suggestions_schema = rg.SuggestionSchema( + question_name=self.question, + # score=pred[0] if pred[0] > 0.5 else pred[1], + # value="yes" if pred[0] > 0.5 else "no", + ) + updated_suggestions = [suggestions_schema] + for sug in rec.suggestions: + if sug.question_name != self.question: + updated_suggestions.append(sug) + updated_records.append(rec) + # self.dataset.update_records(updated_records) diff --git a/argilla-v1/src/extralit_v1/pipeline/update/schema.py b/argilla-v1/src/extralit_v1/pipeline/update/schema.py new file mode 100644 index 000000000..f746d7086 --- /dev/null +++ b/argilla-v1/src/extralit_v1/pipeline/update/schema.py @@ -0,0 +1,52 @@ +import json +import logging +from typing import Optional, Dict + +import argilla_v1 as rg +import pandera as pa + +__all__ = ["update_table_schema", "update_record_table_schema"] + + +def update_table_schema(table_json_str: str, schema_json: Dict, reference: str, schema_name: str) -> str: + table_json = json.loads(table_json_str) + + if reference: + table_json["reference"] = reference + + table_json["validation"] = schema_json + + return json.dumps(table_json) + + +def update_record_table_schema( + record: rg.FeedbackRecord, schema: pa.DataFrameSchema, field: str, answer: Optional[str] = None +) -> rg.FeedbackRecord: + reference = record.metadata.get("reference") + schema_name = record.metadata["type"] + assert schema_name == schema.name, f"Schema name `{schema_name}` does not match the schema provided." + + schema_json = json.loads(schema.to_json()) + + try: + record.fields[field] = update_table_schema( + record.fields[field], schema_json=schema_json, reference=reference, schema_name=schema_name + ) + except Exception as e: + logging.error( + f'Unable to update {schema_name} schema for field `{field}` or answer `{answer}` in {record.metadata["reference"]}. 
\n{e}' + ) + raise e + + ### Update tables in responses (doesn't yet update in argilla backend) + # for i, response in enumerate(record.responses): + # if answer not in response.values: + # continue + # elif not is_json_table(response.values[answer].value): + # continue + # + # record.responses[i].values[answer].value = update_table_schema( + # response.values[answer].value, schema_json=schema_json, reference=reference, + # schema_name=schema_name) + + return record diff --git a/argilla-v1/src/extralit_v1/pipeline/update/suggestion.py b/argilla-v1/src/extralit_v1/pipeline/update/suggestion.py new file mode 100644 index 000000000..9bea7ab98 --- /dev/null +++ b/argilla-v1/src/extralit_v1/pipeline/update/suggestion.py @@ -0,0 +1,40 @@ +from typing import Union, List, Optional + +import argilla_v1 as rg + +__all__ = ["update_record_suggestions"] + +from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord + + +def update_record_suggestions( + record: RemoteFeedbackRecord, suggestions: Union[rg.SuggestionSchema, List[rg.SuggestionSchema]] +) -> rg.FeedbackRecord: + if not isinstance(suggestions, list): + suggestions = [suggestions] + + # Create a dictionary from the new suggestions + new_suggestions_dict = { + (s.question_name, s.type, s.agent): s for s in suggestions if s.question_name in record.question_name_to_id + } + + if new_suggestions_dict: + # Keep only the suggestions that are not in the new suggestions + updated_suggestions = [ + s for s in record.suggestions if (s.question_name, s.type, s.agent) not in new_suggestions_dict + ] + + record.suggestions = updated_suggestions + list(new_suggestions_dict.values()) + + return record + + +def get_record_suggestion_value( + record: RemoteFeedbackRecord, question_name: str, users: List[rg.User] = None +) -> Optional[str]: + usernames = {user.username for user in users} if users else None + for suggestion in record.suggestions: + if suggestion.question_name == question_name and (not usernames or 
suggestion.agent in usernames): + return suggestion.value + + return None diff --git a/argilla/src/extralit/preprocessing/__init__.py b/argilla-v1/src/extralit_v1/preprocessing/__init__.py similarity index 100% rename from argilla/src/extralit/preprocessing/__init__.py rename to argilla-v1/src/extralit_v1/preprocessing/__init__.py diff --git a/argilla-v1/src/extralit_v1/preprocessing/alignment.py b/argilla-v1/src/extralit_v1/preprocessing/alignment.py new file mode 100644 index 000000000..24ac885a1 --- /dev/null +++ b/argilla-v1/src/extralit_v1/preprocessing/alignment.py @@ -0,0 +1,321 @@ +import copy +import difflib +import os +from collections import Counter +from typing import List, Optional, Tuple, Union, Dict, Any + +import argilla_v1 as rg +import pandas as pd +from argilla.client.feedback.utils import image_to_html +from pydantic.v1 import BaseModel, Field, validator +from rapidfuzz import fuzz +from unstructured.documents.elements import Element, Header, FigureCaption, Image, Footer, Table as UnstructuredTable + +from extralit_v1.convert.text import find_longest_superstrings +from extralit_v1.preprocessing.segment import TextSegment, TableSegment, CHUNK_DELIM, FigureSegment, Segments +from extralit_v1.preprocessing.tables import SAMPLE_HTML_TABLE + + +class Alignments(BaseModel): + items: List["SegmentsAlignment"] = Field(default_factory=list, description="List of SegmentsAlignment objects") + + def to_records(self, dataset: rg.FeedbackDataset, fill_missing_tables=False, **kwargs) -> List[rg.FeedbackRecord]: + records = [item.to_record(dataset=dataset, **kwargs) for item in self.items] + records = [record for record in records if record is not None] + + if fill_missing_tables: + self.insert_missing_tables(records, **kwargs) + + return records + + def insert_missing_tables(self, records: List[rg.FeedbackRecord], **kwargs): + captured_numbers = [item.number for item in self.items if item.number] + numbers_count = Counter(captured_numbers) + max_number = 
max([item.number for item in self.items if item.number], default=0) + + for number in range(1, max_number + 1): + if number in captured_numbers: + continue + # Check that the previous number is not duplicated + if number > 1 and numbers_count.get(number - 1, 1e6) > 1: + continue + + metadata = copy.deepcopy(kwargs.get("metadata", {})) + metadata["number"] = f"Table {number}" + fields = { + "header": f"This table #{number} was not detected", + "metadata": pd.DataFrame.from_dict(metadata, orient="index").T.to_markdown(index=False), + } + + suggestions = copy.deepcopy(kwargs.get("suggestions", [])) + suggestions.extend( + [ + {"question_name": "header-correction", "value": f"Table {number}"}, + {"question_name": "text-correction", "value": SAMPLE_HTML_TABLE}, + ] + ) + missing_record = rg.FeedbackRecord(fields=fields, suggestions=suggestions, metadata=metadata) + records.insert(number - 1, missing_record) + + def __repr_str__(self, join_str: str) -> str: + return f"\n " + f"{join_str}\n ".join(f"{type(item).__name__}({item})" for item in self.items) + + def __getitem__(self, index): + return self.items[index] + + def __len__(self): + return len(self.items) + + +class SegmentsAlignment(BaseModel): + header: str = Field(..., description="Header of the element", example="Abstract") + type: str = Field(..., description="Type of the element", example="text") + page_number: int = Field(..., example=1) + summary: Optional[str] = Field(..., description="Summary of the content", example="Summary") + number: Optional[int] = Field(None, description="Number of the table/figure", example=1) + extractions: Dict[str, Any] = Field(default_factory=dict, description="Extractions from different sources") + image: Optional[str] = Field(None) + probability: Optional[float] = Field(None, description="Probability of the detection algorithm") + + @validator("extractions", each_item=True, pre=True) + def check_segment_types(cls, segment: Union[FigureSegment, TableSegment]): + return segment + 
+ def __getitem__(self, key): + return self.extractions[key] + + def __repr_str__(self, join_str: str) -> str: + extractions_str = f"{join_str}\n\t" + f"{join_str}\n\t".join( + f'"{k}"={type(v).__name__}({v})' for k, v in self.extractions.items() + ) + return f"page_number={self.page_number}, number={self.number}, extractions={{ {extractions_str} }}" + + def to_record( + self, + dataset: rg.FeedbackDataset, + fields: Optional[Dict[str, str]] = None, + suggestions: Optional[List[Dict[str, str]]] = None, + metadata: Optional[Dict[str, str]] = None, + **kwargs, + ) -> Optional[rg.FeedbackRecord]: + fields: Dict[str, str] = {**(fields or {})} + if isinstance(self.header, str): + fields["header"] = self.header.strip() + else: + fields["header"] = "" + + if self.summary: + fields["header"] += f"{CHUNK_DELIM}{self.summary}" + + suggestions = copy.deepcopy(suggestions) or [] + + if self.type.lower() == "text": + pass + + elif self.type.lower() in ["table", "figure"]: + for source, segment in self.extractions.items(): + if segment.summary: + fields["header"] += f"{CHUNK_DELIM}{segment.summary}" + + if segment.html: + if source == "nougat": + fields["text-1"] = segment.html + elif source == "unstructured": + fields["text-2"] = segment.html + elif source == "llmsherpa": + fields["text-3"] = segment.html + elif source == "deepdoctection": + fields["text-4"] = segment.html + elif source == "pdffigures2": + fields["text-5"] = segment.html + + if isinstance(self.image, str) and os.path.exists(self.image): + fields["image"] = image_to_html(self.image) + + else: + print("Skipped", self.type, metadata, self.extractions.keys()) + + # Metadata + metadata = copy.deepcopy(metadata) or {} + metadata["type"] = f"{self.type}" + if self.number: + metadata["number"] = f"{self.type} {self.number}" + if self.probability: + metadata["probability"] = self.probability + if self.page_number: + metadata["page_number"] = str(self.page_number) + + if dataset.field_by_name("metadata") and not 
fields.get("metadata"): + fields["metadata"] = pd.DataFrame.from_dict(metadata, orient="index").T.to_markdown(index=False) + + record = rg.FeedbackRecord(fields=fields, metadata=metadata, **kwargs) + + return record + + +Alignments.update_forward_refs() +SegmentsAlignment.update_forward_refs() + + +def merge_extractions(**extraction_sources: Dict[str, Segments]) -> Alignments: + extraction_sources = { + source: segments for source, segments in extraction_sources.items() if segments is not None and len(segments) + } + pointers = {source: 0 for source in extraction_sources} + merged_data = [] + + while any( + pointer < len(extraction_list) + for pointer, extraction_list in zip(pointers.values(), extraction_sources.values()) + ): + current_items = [] + for source, segments in extraction_sources.items(): + index = pointers[source] + if index < len(segments): + current_item = segments[index] + current_items.append((source, current_item)) + + # groups_ordering = group_segments([segment for (source, segment) in current_items], threshold=30) + # current_items = [(source, segment, number) for (source, segment), number in \ + # zip(current_items, groups_ordering)] + + # Sort by page number, then by table number, handling None values + valid_items: List[Tuple[str, Union[TableSegment, FigureSegment]]] = sorted( + current_items, key=lambda x: (x[1].page_number, x[1].number or float("inf")) + ) + if not valid_items: + break + + current_page_number = valid_items[0][1].page_number + filtered_items: List[Tuple[str, TextSegment]] = [ + (source, segment) for source, segment, *number in valid_items if segment.page_number == current_page_number + ] + + # Combine headers for unique values + unique_headers = set(segment.header.strip() for (source, segment) in filtered_items if segment.header) + combined_header = CHUNK_DELIM.join(find_longest_superstrings(unique_headers, similarity_threshold=90)) + summary = next( + (segment.summary for source, segment in filtered_items if 
getattr(segment, "summary", None)), None + ) + number = next((segment.number for source, segment in filtered_items if getattr(segment, "number", None)), None) + image = next((segment.image for source, segment in filtered_items if getattr(segment, "image", None)), None) + probabilities = [ + segment.probability for source, segment in filtered_items if getattr(segment, "probability", None) + ] + + # Create SegmentsAlignment object + segment_alignment = SegmentsAlignment( + header=combined_header.strip(), + summary=summary, + page_number=current_page_number, + number=number, + extractions={source: segment for source, segment in filtered_items}, + image=image, + type=filtered_items[0][1].type, + probability=max(probabilities, default=None), + ) + merged_data.append(segment_alignment) + + # Increment pointers for all sources that had the current header + for source, _ in filtered_items: + pointers[source] += 1 + + return Alignments(items=merged_data) + + +def group_segments_by_similarity(segments: List[TextSegment], threshold=80.0) -> List[int]: + if len(segments) == 1: + return [0] + + groups = [] + for seg in segments: + for group in groups: + if max(fuzz.ratio(seg.text_cleaned(), group_item.text_cleaned()) for group_item in group) >= threshold: + group.append(seg) + break + else: + groups.append([seg]) + + groups.sort(key=len, reverse=True) + + ordered_groups = [] + for seg in segments: + for i, group in enumerate(groups): + if seg in group: + ordered_groups.append(i) + break + + return ordered_groups + + +def find_matching_text_elem( + text: str, elements: List[Element], page_number: int, start_index: Optional[int] = None, thresh=0.5 +) -> List[int]: + matches = [] + + for i in range(start_index or 0, len(elements)): + elem = elements[i] + if isinstance(elem, (FigureCaption, Header, Footer, UnstructuredTable, Image)): + continue + + if page_number <= elem.metadata.page_number <= page_number + 1: + ratio = difflib.SequenceMatcher(None, text, elem.text).ratio() + if 
ratio > thresh: + matches.append(i) + + return matches + + +# @DeprecationWarning +# def merge_text_segments(doc: Document, elements: List[Element], thresh=0.5) -> List[SegmentsAlignment]: +# def llmsherpa_section_to_segment(section: Section) -> TextSegment: +# assert not isinstance(section, Table), f"section must be a Section, got {type(section)}" +# +# if section.children: +# text = CHUNK_DELIM.join(get_paragraphs(section)) +# else: +# text = None +# +# header = section.title +# +# return TextSegment( +# header=header, +# page_number=section.page_idx + 1, +# text=text, +# html=section.to_html(), +# source='llmsherpa', +# original=section, +# ) +# +# merged = [] +# +# element_index = 0 +# for section in doc.sections(): +# section_page = section.page_idx + 1 +# section_title = section.title +# if not section.children: +# continue +# +# matched_elements: List[Element] = [] +# for child in section.children: +# if isinstance(child, llmsherpa.readers.layout_reader.Paragraph): +# child_text = child.to_text(include_children=True, recurse=True).replace('\n', ' ') +# +# matched_elem_idx = find_matching_text_elem(child_text, elements, page_number=section_page, +# start_index=element_index, thresh=thresh) or [element_index] +# matched_elements.extend([elements[i] for i in matched_elem_idx if i < len(elements)]) +# +# else: +# continue +# +# element_index = max(matched_elem_idx) + 1 +# +# # If texts are similar enough, create a record +# if matched_elements: +# record = SegmentsAlignment( +# header=section_title, +# llmsherpa=[llmsherpa_section_to_segment(section)], +# unstructured=[TextSegment.from_unstructured(elem, header=section_title) for elem in matched_elements]) +# merged.append(record) +# +# return merged diff --git a/argilla-v1/src/extralit_v1/preprocessing/document.py b/argilla-v1/src/extralit_v1/preprocessing/document.py new file mode 100644 index 000000000..86e5b23a8 --- /dev/null +++ b/argilla-v1/src/extralit_v1/preprocessing/document.py @@ -0,0 +1,284 @@ 
+import glob +import os +from os.path import join +from typing import Tuple, Optional + +import dill +import pandas as pd + +from extralit_v1.preprocessing.segment import Segments +from extralit_v1.storage.files import FileHandler, StorageType + +__all__ = [ + "load_segments", + "create_or_load_unstructured_segments", + "create_or_load_llmsherpa_segments", + "create_or_load_nougat_segments", + "create_or_load_pdffigures2_segments", + "create_or_load_deepdoctection_segments", +] + + +def load_segments(file_handler: FileHandler, path: str) -> Tuple[Segments, Segments, Segments]: + texts = Segments() + tables = Segments() + figures = Segments() + + if file_handler.exists(join(path, "texts.json")): + texts = Segments.parse_raw(file_handler.read_text(join(path, "texts.json"))) + if file_handler.exists(join(path, "tables.json")): + tables = Segments.parse_raw(file_handler.read_text(join(path, "tables.json"))) + if file_handler.exists(join(path, "figures.json")): + figures = Segments.parse_raw(file_handler.read_text(join(path, "figures.json"))) + + return texts, tables, figures + + +def create_or_load_unstructured_segments( + paper: pd.Series, + file_handler: FileHandler, + preprocessing_path="data/preprocessing/", + load_only=True, + redo=False, + save=True, +) -> Tuple[Optional[Segments], Optional[Segments], Optional[Segments]]: + from extralit_v1.preprocessing.methods import unstructured + from unstructured.partition.pdf import partition_pdf + from unstructured.staging.base import elements_to_json, elements_from_json + + cache_path: str = join(preprocessing_path, "unstructured", paper.name) + model_output_path = join(cache_path, "elements.json") + figures_path = join(cache_path, "figures") + + if file_handler.exists(cache_path) and load_only: + return load_segments(file_handler, cache_path) + + if not file_handler.exists(model_output_path) or redo: + print(f"Unstructured {paper.name}: {cache_path}", flush=True) + os.makedirs(figures_path, exist_ok=True) + elements = 
partition_pdf( + filename=paper.file_path, + strategy="hi_res", + infer_table_structure=True, + chunking_strategy={ + "multipage_sections": True, + "include_metadata": True, + "combine_text_under_n_chars": 500, + }, + extract_images_in_pdf=True, + image_output_dir_path=figures_path, + extract_image_block_output_dir=figures_path, + pdf_image_dpi=600, + ) + if save: + elements_to_json(elements, filename=model_output_path) + else: + elements = elements_from_json(model_output_path) + + texts = unstructured.get_text_segments(elements) + tables = unstructured.get_table_segments(elements, output_dir=figures_path) + figures = unstructured.get_figure_segments(elements) + + if save: + file_handler.write_text(join(cache_path, "texts.json"), texts.json()) + file_handler.write_text(join(cache_path, "tables.json"), tables.json()) + file_handler.write_text(join(cache_path, "figures.json"), figures.json()) + + return texts, tables, figures + + +def create_or_load_llmsherpa_segments( + paper: pd.Series, + file_handler: FileHandler, + preprocessing_path="data/preprocessing/", + load_only=True, + redo=False, + save=True, +) -> Tuple[Optional[Segments], Optional[Segments], Optional[Segments]]: + from extralit_v1.preprocessing.methods import llmsherpa + from llmsherpa.readers import LayoutPDFReader + + cache_path: str = join(preprocessing_path, "llmsherpa", paper.name) + model_output_path = join(cache_path, "document.pkl") + + if file_handler.exists(cache_path) and load_only: + return load_segments(file_handler, cache_path) + + if not file_handler.exists(model_output_path) or redo: + print(f"Llmsherpa {paper.name}: {cache_path}", flush=True) + os.makedirs(cache_path, exist_ok=True) + pdf_reader = LayoutPDFReader( + "https://readers.llmsherpa.com/api/document/developer/parseDocument?renderFormat=all" + ) + try: + document = pdf_reader.read_pdf(paper.file_path) + except Exception as e: + print(e) + return None, None, None + if save: + with open(model_output_path, "wb") as file: + 
def create_or_load_pdffigures2_segments(
    paper: pd.Series,
    file_handler: FileHandler,
    preprocessing_path="data/preprocessing/",
    jar_path="~/bin/pdffigures2.jar",
    load_only=True,
    redo=False,
    save=True,
) -> Tuple[Optional[Segments], Optional[Segments], Optional[Segments]]:
    """Extract table and figure segments from a paper with the pdffigures2 jar.

    Args:
        paper: Row describing the paper; ``paper.name`` names the cache folder and
            ``paper.file_path`` is the PDF handed to pdffigures2.
        file_handler: Storage handler used to test for and write the cached JSON files.
        preprocessing_path: Root directory of the preprocessing cache.
        jar_path: Location of the pdffigures2 jar. A leading ``~`` is expanded.
        load_only: When True and ``figures.json`` is already cached, return the cached segments.
        redo: Re-run pdffigures2 even if its JSON output already exists.
        save: Write ``tables.json`` and ``figures.json`` to the cache.

    Returns:
        ``(None, tables, figures)`` on success (pdffigures2 yields no text segments), or
        ``(None, None, None)`` when the pdffigures2 output cannot be parsed.

    Raises:
        FileNotFoundError: If the jar does not exist at ``jar_path``.
    """
    # Imported locally so this function does not depend on a module-level import.
    import subprocess

    # `os.path.exists` does not expand "~", so the default path could never be found before.
    jar_path = os.path.expanduser(jar_path)
    if not os.path.exists(jar_path):
        raise FileNotFoundError(f"pdffigures2 jar not found: {jar_path}")

    # NOTE: the cache folder is spelled "pdffigure2" (no "s"); kept so existing caches stay valid.
    cache_path: str = join(preprocessing_path, "pdffigure2", paper.name)
    _, file_name_ext = os.path.split(paper.file_path)
    file_name, _ = os.path.splitext(file_name_ext)
    model_output_path = join(cache_path, f"{file_name}.json")

    if file_handler.exists(join(cache_path, "figures.json")) and load_only:
        return load_segments(file_handler, cache_path)

    if not file_handler.exists(model_output_path) or redo:
        print(f"pdffigures2 {paper.name}: {cache_path}", flush=True)
        os.makedirs(cache_path, exist_ok=True)
        output_prefix = cache_path.rstrip("/") + "/"
        # An argument list (no shell) keeps paths with spaces or shell metacharacters intact.
        command = [
            "java",
            "-jar",
            jar_path,
            str(paper.file_path),
            "-m",
            output_prefix,
            "-d",
            output_prefix,
            "--figure-format",
            "png",
        ]
        try:
            # check=False mirrors os.system: a failing run is detected below when parsing the output.
            subprocess.run(command, check=False)
        except OSError as e:
            # e.g. `java` is not installed; stay best-effort like the previous os.system call.
            print(e)

    try:
        segments = Segments.from_pdffigures2(model_output_path)
    except Exception as e:
        print(e)
        return None, None, None

    tables = Segments()
    figures = Segments()
    for segment in segments.items:
        if segment.type == "table":
            tables.items.append(segment)
        elif segment.type == "figure":
            figures.items.append(segment)

    if save:
        file_handler.write_text(join(cache_path, "tables.json"), tables.json())
        file_handler.write_text(join(cache_path, "figures.json"), figures.json())

    return None, tables, figures
load_segments(file_handler, cache_path) + + if not file_handler.exists(model_output_path) or redo: + print(f"Deepdoctection {paper.name}: {cache_path}", flush=True) + os.makedirs(cache_path, exist_ok=True) + + os.environ["USE_DD_PILLOW"] = "True" + os.environ["USE_DD_OPENCV"] = "False" + + analyzer = dd.get_dd_analyzer( + config_overwrite=[ + "PT.LAYOUT.WEIGHTS=microsoft/table-transformer-detection/pytorch_model.bin", + "PT.ITEM.WEIGHTS=microsoft/table-transformer-structure-recognition-v1.1-all/pytorch_model.bin", + "PT.ITEM.FILTER=['table']", + "USE_PDF_MINER=True", + "USE_OCR=True", + ] + ) + + try: + df = analyzer.analyze(path=paper.file_path) + df.reset_state() + except Exception as e: + print(e) + return None, None, None + + doc = iter(df) + pages = [] + for page_num in range(1, len(df) + 1): + try: + page: dd.Page = next(doc) + pages.append(page) + page.save(image_to_json=True, path=join(cache_path, f"page_{page_num}.json")) if save else None + except StopIteration: + break + else: + pages = [] + for path in sorted(glob.glob(join(cache_path, "page_*.json"))): + pages.append(dd.Page.from_file(path)) + + tables = deepdoctection.get_table_segments(pages, output_dir=join(cache_path, "tables"), redo=redo) + # figures = deepdoctection.get_figure_segments(pages) + + if save: + file_handler.write_text(join(cache_path, "tables.json"), tables.json()) + # file_handler.write_text(join(cache_path, 'figures.json'), figures.json()) + + return None, tables, None diff --git a/argilla-v1/src/extralit_v1/preprocessing/figures.py b/argilla-v1/src/extralit_v1/preprocessing/figures.py new file mode 100644 index 000000000..a02e1f9e6 --- /dev/null +++ b/argilla-v1/src/extralit_v1/preprocessing/figures.py @@ -0,0 +1,82 @@ +import base64 +import io +import logging +from typing import Optional + +from PIL import Image +from bs4 import BeautifulSoup +from pydantic import BaseModel, Field + +from extralit_v1.extraction.staging import heal_json + + +class 
def encode_image(image_path: str, max_size=(1000, 1000), resize_only=True) -> Optional[str]:
    """Shrink an image to fit within ``max_size`` and optionally return it base64-encoded.

    Args:
        image_path: Path of the image file to open.
        max_size: ``(width, height)`` bound passed to ``Image.thumbnail``.
        resize_only: When True, overwrite the file in place if the thumbnail step changed
            its size and return None. When False, leave the file untouched and return the
            resized image as a base64-encoded PNG string.

    Returns:
        The base64 string of the PNG bytes when ``resize_only`` is False, otherwise None.
    """
    with Image.open(image_path) as img:
        original_size = img.size
        img.thumbnail(max_size)
        new_size = img.size

        if resize_only:
            # Only rewrite the file when the thumbnail step actually changed the size.
            if original_size != new_size:
                img.save(image_path)
            return None

        byte_arr = io.BytesIO()
        img.save(byte_arr, format="PNG")

        # b64encode needs bytes; passing the BytesIO object itself raised TypeError.
        return base64.b64encode(byte_arr.getvalue()).decode("utf-8")
def get_table_segments(pages: List[Page], output_dir=None, redo=True) -> Segments:
    """Build a TableSegment for every table found on the given deepdoctection pages.

    Args:
        pages: Analyzed pages; ``None`` entries are skipped. Page numbers are assigned
            from the position in this list, starting at 1.
        output_dir: Directory for the cropped table images. When not given, no images are
            written and each segment's ``image`` is None.
        redo: Forwarded to ``extract_image``.

    Returns:
        Segments containing one TableSegment per detected table.
    """
    # The default output_dir=None used to crash in os.makedirs; images are now optional.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    segments = Segments()
    for page_num, page in enumerate(pages, start=1):
        if page is None:
            continue

        # The page image is only needed for cropping, so skip it when nothing is saved.
        page_image = None
        if output_dir:
            page_image = Image.fromarray(page.viz(show_layouts=False, show_cells=False, show_token_class=False))

        captured_indices = set()
        for table_num, table in enumerate(page.tables, start=1):
            coordinates = get_coordinates(table, page)

            image_path = None
            if page_image is not None:
                image_path = extract_image(
                    page_image, coordinates, title=f"table_{page_num}_{table_num}.png", output_dir=output_dir, redo=redo
                )

            table_index = next(
                (i for i, ann in enumerate(page.annotations) if ann._annotation_id == table.annotation_id), None
            )

            def is_stacked_text(ann, coordinates=coordinates, page=page):
                # `!= False` is kept on purpose: any result other than False keeps the candidate.
                return (
                    ann._category_name == LayoutType.text
                    and coordinates.is_vstacked(get_coordinates(ann, page), width="smaller") != False  # noqa: E712
                )

            # Find the table header and footer among the annotations around the table
            header, footer = get_table_header_footer(
                page.annotations,
                start_index=table_index,
                look_ahead=2,
                get_text_fn=lambda x: x.text,
                header_pattern=r"(?i)(Table)\s?([Il|\d]+\.?)(.*|$)",
                footer_pattern=r".*",
                header_filter_fn=is_stacked_text,
                footer_filter_fn=is_stacked_text,
                captured_indices=captured_indices,
            )

            segment = TableSegment(
                # Defensive: the helper is defined elsewhere and may return None for a missing part.
                header=((header or "") + (footer or "")).strip(),
                page_number=page_num,
                text=table.text,
                html=table.html,
                image=image_path,
                probability=table.score,
                coordinates=coordinates,
                source="deepdoctection",
                original=table,
            )
            segments.items.append(segment)

    return segments
def get_text_segments(document: Document) -> Segments:
    """Convert the sections of an llmsherpa document into TextSegments.

    A section is skipped when it is a Table, has no children, or when spaCy finds no
    named entities in its paragraph text.

    Args:
        document: Parsed llmsherpa document.

    Returns:
        Segments with one TextSegment per retained section.
    """
    segments = Segments()
    nlp = spacy.load("en_core_web_sm")
    for section in document.sections():
        if not isinstance(section, Section) or isinstance(section, Table):
            continue

        if not section.children:
            continue

        # Join with a space: an empty separator fused the last word of one paragraph
        # with the first word of the next.
        text = " ".join(get_paragraphs(section))

        # Skip sections in which spaCy detects no named entities
        if not any(nlp(text).ents):
            continue

        segment = TextSegment(
            header=section.title,
            page_number=section.page_idx + 1,
            text=text,
            html=section.to_html().replace("—", "-").replace("·", "."),
            source="llmsherpa",
            original=section,
        )

        segments.items.append(segment)

    return segments
def correct_column_definition(latex_table: str) -> str:
    r"""Rewrite the column spec of a LaTeX ``tabular`` to match its widest row.

    The number of columns is taken from the row with the most unescaped ``&``
    separators, and the spec is replaced with that many centred (``c``) columns.

    Args:
        latex_table: LaTeX source containing a ``\begin{tabular}{...}`` environment.

    Returns:
        The LaTeX source with every ``\begin{tabular}{...}`` spec replaced, or the input
        unchanged when it contains no ``\begin{tabular}{...}``.
    """
    # Column spec, allowing one level of nested braces such as `p{3cm}`.
    spec_pattern = r"\\begin\{tabular\}\{(?:[^{}]|\{[^{}]*\})*\}"
    begin = re.search(spec_pattern, latex_table)
    if begin is None:
        return latex_table

    # Only look at the tabular body, so the spec and any surrounding text are not counted.
    end = latex_table.find(r"\end{tabular}", begin.end())
    body = latex_table[begin.end() : end if end != -1 else len(latex_table)]

    # Rows end with a double backslash. Splitting on a single backslash would also
    # break rows at every macro (e.g. \textbf) and undercount the columns.
    rows = re.split(r"\\\\", body)

    # Count only unescaped `&`; `\&` is a literal ampersand, not a column separator.
    max_columns = max((len(re.findall(r"(?<!\\)&", row)) for row in rows), default=0)

    corrected_definition = " ".join("c" for _ in range(max_columns + 1))
    replacement = r"\begin{tabular}{" + corrected_definition + "}"

    # A callable replacement avoids re.sub interpreting backslashes in the new text.
    return re.sub(spec_pattern, lambda _match: replacement, latex_table)
def get_table_segments(elements: List[Element], max_caption_look_head=5, output_dir=None, redo=False) -> Segments:
    """Build TableSegments from the Table elements of an unstructured partition.

    Args:
        elements: Elements returned by unstructured for one PDF.
        max_caption_look_head: How many neighbouring elements ``get_table_header_footer``
            may inspect when searching for the table's header and footer.
        output_dir: Directory for cropped table images. When not given, no images are
            produced and each segment's ``image`` is None.
        redo: Forwarded to ``extract_image``.

    Returns:
        Segments containing the tables that pass ``table_extraction_qc``.
    """
    # The page images are only used for cropping, so render the PDF only when images are wanted.
    page_image = None
    if output_dir:
        try:
            pdf_path = join(elements[0].metadata.file_directory, elements[0].metadata.filename)
            page_image = convert_from_path(pdf_path, dpi=450)
        except Exception as e:
            # Best effort: without page images the segments are still produced, just without crops.
            print(e)
            page_image = None

    tables = Segments()
    captions_taken = set()
    for i, elem in enumerate(elements):
        if not isinstance(elem, Table):
            continue

        page_number = elem.metadata.page_number
        coordinates = Coordinates(**elem.metadata.coordinates.to_dict())

        def is_stacked_text(x, page_number=page_number, coordinates=coordinates):
            # `!= False` is kept on purpose: any result other than False keeps the candidate.
            return (
                isinstance(x, (Text, FigureCaption))
                and page_number == x.metadata.page_number
                and coordinates.is_vstacked(Coordinates(**x.metadata.coordinates.to_dict()), width="smaller")
                != False  # noqa: E712
            )

        # Find the closest header/footer element either behind or ahead of the index `i`
        header, footer = get_table_header_footer(
            elements,
            start_index=i,
            look_ahead=max_caption_look_head,
            get_text_fn=lambda x: x.text,
            header_pattern=r"(?i)(Table)\s?([Il|\d]+\.?)(.*|$)",
            footer_pattern=r".*",
            header_filter_fn=is_stacked_text,
            footer_filter_fn=is_stacked_text,
            captured_indices=captions_taken,
        )

        image_path = None
        # page_number is 1-based and indexes page_image[page_number - 1], so the last page
        # is page_number == len(page_image). The previous `<` excluded it.
        if page_image and page_number and page_number <= len(page_image):
            image_path = extract_image(
                page_image[page_number - 1],
                coordinates=coordinates,
                title=f"table_{i}",
                output_dir=output_dir,
                redo=redo,
            )

        segment = TableSegment(
            header=header.strip() if header else None,
            page_number=page_number,
            coordinates=coordinates,
            image=image_path,
            probability=getattr(elem.metadata, "detection_class_prob", None),
            text=elem.text,
            html=elem.metadata.text_as_html.replace("—", "-").replace("·", "."),
            source="unstructured",
            original=elem,
        )

        if table_extraction_qc(segment):
            tables.items.append(segment)

    return tables
- j) + break + elif i + j < len(elements) and isinstance(elements[i + j], FigureCaption) and i + j not in captions_taken: + title_el = elements[i + j].text + captions_taken.add(i + j) + break + + if skip_empty_header and (not title_el or "fig" not in title_el.lower()): + continue + + segment = FigureSegment( + header=title_el, + page_number=elem.metadata.page_number, + coordinates=elem.metadata.coordinates.to_dict(), + image=getattr(elem.metadata, "image_path", None), + probability=getattr(elem.metadata, "detection_class_prob", None), + text=elem.text, + html=elem.metadata.text_as_html, + source="unstructured", + original=elem, + ) + figures.items.append(segment) + + return figures + + +def get_text_segments(elements: List[Element]) -> Segments: + segments = Segments() + parent_map = {} + watermark_pattern = r"(?:\b[\w/]\s)+" + nlp = spacy.load("en_core_web_sm") + + for elem in elements: + if len(elem.text) < 5 or isinstance(elem, (Table, FigureCaption, Image)): + continue + elif re.match(watermark_pattern, elem.text): + continue + elif "reference" not in elem.text.lower() and not any(nlp(elem.text).ents): + continue + + segment = TextSegment( + level=getattr(elem.metadata, "level", None), + text=elem.text, + page_number=elem.metadata.page_number, + coordinates=elem.metadata.coordinates.to_dict(), + probability=getattr(elem.metadata, "detection_class_prob", None), + source="unstructured", + original=elem, + ) + segments.items.append(segment) + + # parent_id = getattr(elem.metadata, 'parent_id', None) + # if parent_id: + # if parent_id not in parent_map: + # parent_map[parent_id] = [] + # parent_map[parent_id].append(segment) + # + # for segment in segments: + # parent_id = getattr(segment.original.metadata, 'parent_id', None) + # if parent_id and parent_id in parent_map: + # segment.children = parent_map[parent_id] + # + # segments = [segment for segment in segments if segment.children] + + return segments diff --git 
a/argilla-v1/src/extralit_v1/preprocessing/segment.py b/argilla-v1/src/extralit_v1/preprocessing/segment.py new file mode 100644 index 000000000..8134ebff7 --- /dev/null +++ b/argilla-v1/src/extralit_v1/preprocessing/segment.py @@ -0,0 +1,278 @@ +import json +import logging +import os.path +import uuid +from typing import Optional, Any, List, Union, Dict + +from llama_index.core.output_parsers import PydanticOutputParser +from llama_index.core.program import MultiModalLLMCompletionProgram +from llama_index.core.readers import SimpleDirectoryReader +from llama_index.core.schema import NodeRelationship, RelatedNodeType +from llama_index.multi_modal_llms.openai import OpenAIMultiModal +from pydantic.v1 import BaseModel, Field, validator + +from extralit_v1.convert.html_table import html_table_to_json, html_to_df, llmsherpa_html_to_df +from extralit_v1.extraction import prompts +from extralit_v1.preprocessing.figures import encode_image, FigureExtractionResponse +from extralit_v1.preprocessing.tables import extract_table_number + +CHUNK_DELIM = "\n\n---\n" + + +class Segments(BaseModel): + items: List[Union["TextSegment", "TableSegment", "FigureSegment"]] = Field( + default_factory=list, description="List of segments" + ) + + def get(self, id: str, header: str = None, default=None): + for item in self.items: + if item.id == id or (header and item.header == header): + return item + + return default + + def make_headers_unique(self) -> None: + header_dict = {} + + for segment in self.items: + if segment.header in header_dict: + parent = segment.relationships.get(NodeRelationship.PARENT) + if parent: + parent_segment = self.get(parent.node_id) + if parent_segment: + segment.header = f"{parent_segment.header}: {segment.header}" + print(segment.id, segment.header) + else: + header_dict[segment.header] = segment + + def __repr_str__(self, join_str: str) -> str: + return "\n " + f"{join_str}\n ".join(f"{type(item).__name__}({item})" for item in self.items) + + 
@validator("items", pre=True, each_item=True) + def parse_segments(cls, v): + if not isinstance(v, dict): + v = v.dict() + + segment_type = v.get("type", "").lower() + if segment_type in {"figure", "image"}: + return FigureSegment(**v) + elif segment_type == "table" or "html" in v: + return TableSegment(**v) + else: + return TextSegment(**v) + + @classmethod + def from_pdffigures2(cls, json_file: str) -> "Segments": + with open(json_file, "r") as f: + data = json.load(f) + + items = [] + for item in data: + mapped_item = { + "header": item["caption"], + "type": item["figType"].lower(), + "text": " ".join(item["imageText"]), + "image": item["renderURL"], + "page_number": item["page"] + 1, + "coordinates": { + "points": [ + [item["regionBoundary"]["x1"], item["regionBoundary"]["y1"]], + [item["regionBoundary"]["x2"], item["regionBoundary"]["y1"]], + [item["regionBoundary"]["x1"], item["regionBoundary"]["y2"]], + [item["regionBoundary"]["x2"], item["regionBoundary"]["y2"]], + ], + "layout_width": item["regionBoundary"]["x2"] - item["regionBoundary"]["x1"], + "layout_height": item["regionBoundary"]["y2"] - item["regionBoundary"]["y1"], + "system": "points", + }, + "source": "pdffigures2", + } + items.append(cls.parse_segments(mapped_item)) + + items = sorted(items, key=lambda x: (x.page_number, x.number or float("inf"))) + + return cls(items=items) + + @property + def duration(self): + return sum([item.duration or 0 for item in self.items if item.duration and item.duration < 1000]) + + def __getitem__(self, index): + return self.items[index] + + def __len__(self): + return len(self.items) + + +class Coordinates(BaseModel): + points: List[List[float]] = Field( + ..., description="List of 4 points, e.g. 
[[x1, y1], [x2, y1], [x1, y2], [x2, y2]]" + ) + layout_width: Optional[int] = Field(None, description="Width of the layout") + layout_height: Optional[int] = Field(None, description="Height of the layout") + system: Optional[str] = Field(description="System of coordinates") + + def __repr_str__(self, join_str: str) -> str: + return "" + + def is_vstacked(self, other: "Coordinates", width: Optional[str] = "same", tol=0.05) -> Optional[bool]: + if not self.points or not other.points: + return None + + if self.layout_width and self.layout_width == other.layout_width: + tolerance = self.layout_width * tol # `tol` fraction of the layout width (5% by default) + else: + tolerance = 10 # pixels + + # Get the x-coordinates of the current bounding box + x1_self = self.points[0][0] + x2_self = self.points[1][0] + + # Get the x-coordinates of the other bounding box + x1_other = other.points[0][0] + x2_other = other.points[1][0] + + # Check if the x-coordinates of the two bounding boxes are approximately equal + if width == "smaller": + return abs(x1_self - x1_other) <= tolerance and (x2_self + tolerance) > x2_other + elif width == "larger": + return abs(x1_self - x1_other) <= tolerance and (x2_self + tolerance) <= x2_other + elif width == "same": + return abs(x1_self - x1_other) <= tolerance and abs(x2_self - x2_other) <= tolerance + else: + return abs(x1_self - x1_other) <= tolerance + + +class TextSegment(BaseModel): + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Unique identifier of the segment", repr=False + ) + + header: Optional[str] = Field(None, description="Header of the element", example="Abstract") + text: str = Field(..., description="Content as plain text", repr=False) + summary: Optional[str] = Field(None, description="Summary of the content") + page_number: Optional[int] = Field(None, description="Page number of the segment") + coordinates: Optional["Coordinates"] = Field( + None, description="Coordinates of the element in the document", repr=False + ) +
level: Optional[int] = Field(None, description="Level of the header") + relationships: Dict[NodeRelationship, RelatedNodeType] = Field( + default_factory=dict, + description="A mapping of relationships to other segments.", + ) + source: Optional[str] = Field(None, description="Source of the element", example="llmsherpa", repr=False) + type: Optional[str] = Field("text", description="Type of the element", example="text", repr=False) + original: Optional[Any] = Field( + None, exclude=True, description="Original object from which the segment was extracted", repr=False + ) + duration: Optional[float] = Field(None, description="Duration spent in manual extraction", repr=False) + + def text_cleaned(self): + return self.text.replace(" | ", " ").replace("---", "").strip() + + def __repr_str__(self, join_str: str) -> str: + return join_str.join( + repr(v) + if a is None + else ( + f'{a}="{v[:100]}...{v[-100:]}"'.replace("\n", "") + if isinstance(v, str) and len(v) > 200 + else f"{a}={v!r}" + ) + for a, v in self.__repr_args__() + if v and a not in {"INCLUDE_METADATA_KEYS"} + ) + + +class TableSegment(TextSegment): + footer: Optional[str] = Field(None, description="Footer of the table or figure, to explain variable acronyms.") + html: Optional[str] = Field(None, description="Content as HTML structured", repr=False) + image: Optional[str] = Field(None, description="URL/filepath of the element's image", repr=False) + probability: Optional[float] = Field(None, description="Probability or confidence of the segment's extraction") + type: Optional[str] = Field("table", description="Type of the element", repr=False) + + @property + def number(self) -> Optional[int]: + return extract_table_number(self.header) + + def __repr_args__(self): + args = super().__repr_args__() + args.append(("number", self.number)) + return args + + def to_df(self, **kwargs): + if self.source == "llmsherpa": + df = llmsherpa_html_to_df(self.html) + else: + df = html_to_df(self.html, **kwargs) + + return df 
+ + def to_csv(self): + df = self.to_df() + csv = df.to_csv(index=bool(df.index.name) or len(df.index.names) > 1) + return csv + + def to_json(self) -> str: + json = None + try: + if self.source == "llmsherpa": + df = llmsherpa_html_to_df(self.html) + json = df.to_json(orient="table", index=bool(df.index.name) or len(df.index.names) > 1) + + else: + json = html_table_to_json(self.html) + except Exception as e: + logging.warning(f"{e}, {self.type}, {self.html}") + + return json + + +class FigureSegment(TableSegment): + type: Optional[str] = Field("figure", description="Type of the element", repr=False) + + @property + def number(self) -> Optional[int]: + return extract_table_number(self.header, pattern=r"(?i)(fig\.?|figure)[.:\s]*([Il|\d]+)", group=2) + + def extract_html_table(self) -> Optional[FigureExtractionResponse]: + if not os.path.exists(self.image): + return None + + try: + encode_image(self.image, resize_only=True) + except Exception as e: + logging.warning(f"Failed to encode image: {e}") + return None + + openai_mm_llm = OpenAIMultiModal( + model="gpt-4o", + temperature=0.0, + max_new_tokens=2048, + image_detail="low", + max_retries=1, + ) + + llm_program = MultiModalLLMCompletionProgram.from_defaults( + image_documents=SimpleDirectoryReader(input_files=[self.image]).load_data(), + output_parser=PydanticOutputParser(FigureExtractionResponse), + prompt=prompts.FIGURE_TABLE_EXT_PROMPT_TMPL, + multi_modal_llm=openai_mm_llm, + ) + + try: + logging.info(f"Extracting figure table: {self.header}") + response: FigureExtractionResponse = llm_program(header_str=self.header) + self.html = response.html + if response.summary: + self.summary = response.summary + except Exception as e: + logging.warning(f"{e}") + return None + + return response + + +TextSegment.update_forward_refs() +Segments.update_forward_refs() +Coordinates.update_forward_refs() diff --git a/argilla-v1/src/extralit_v1/preprocessing/tables.py b/argilla-v1/src/extralit_v1/preprocessing/tables.py new 
file mode 100644 index 000000000..afaffc19a --- /dev/null +++ b/argilla-v1/src/extralit_v1/preprocessing/tables.py @@ -0,0 +1,124 @@ +import re +from typing import TYPE_CHECKING, List, Callable, Any, Tuple, Set, Optional + +if TYPE_CHECKING: + from extralit_v1.preprocessing.segment import TableSegment + + +def table_extraction_qc(segment: "TableSegment") -> bool: + is_valid = True + + try: + df = segment.to_df() + except Exception as e: + return False + + df = df.replace("", None) + df = df.dropna(axis=0, how="all").dropna(axis=1, how="all") + if min(df.shape) <= 1: + is_valid = False + + return is_valid + + +def get_table_header_footer( + elements: List[Any], + start_index: int, + look_ahead: int = 2, + header_pattern: str = r"(?i)(Table)\s?(\d+\.?)(.*|$)", + footer_pattern: str = None, + current_pattern: str = None, + get_text_fn: Callable[[Any], str] = lambda x: x.text, + header_filter_fn: Callable[[Any], bool] = lambda x: True, + footer_filter_fn: Callable[[Any], bool] = None, + captured_indices: Set[int] = None, +) -> Tuple[str, str]: + if start_index is None: + return "", "" + + header = "" + footer = "" + captured_indices = captured_indices or set() + this_elem = elements[start_index] + + if current_pattern: + match = re.search(current_pattern, get_text_fn(this_elem)) + if match: + header += match.group() + "\n" + captured_indices.add(start_index) + + for j in range(1, look_ahead + 1): + # Check the preceding element + pre_idx = start_index - j + if not header.strip() and pre_idx >= 0: + pre_elem = elements[pre_idx] + if pre_idx in captured_indices or (header_filter_fn and not header_filter_fn(pre_elem)): + continue + + match = re.search(header_pattern, get_text_fn(pre_elem)) + if match: + header += match.group() + "\n" + captured_indices.add(pre_idx) + + # Check the succeeding element + suc_idx = start_index + j + if footer_pattern and footer_filter_fn and not footer.strip() and suc_idx < len(elements): + suc_elem = elements[suc_idx] + if suc_idx in 
captured_indices or (footer_filter_fn and not footer_filter_fn(suc_elem)): + continue + + match = re.search(footer_pattern, get_text_fn(suc_elem)) + if match: + footer += match.group() + "\n" + captured_indices.add(suc_idx) + + return header.strip(), footer.strip() + + +def zigzag_indices(i: int, end: int, start=1): + """ + Zigzag indices generator. + Args: + i: + end: + start: + + Returns: + + """ + if i is None: + return [] + + for j in range(start, end + 1): + yield i - j + yield i + j + + +def extract_table_number(header: str, pattern=r"(?i)Table[:.\s]+([iIl|\d]+)", group=1) -> Optional[str]: + # Regular expression to capture digits after 'Table' + if not isinstance(header, str): + return None + + match = re.search(pattern, header) + if match: + # Convert the captured digits to an integer + try: + str_int = match.group(group).replace("I", "1").replace("l", "1").replace("|", "1") + return int(str_int) + except ValueError: + return None + + return None + + +SAMPLE_HTML_TABLE = """
+ + + + + + + + +
Column 1Column 2
Data 1Data 2
+""" diff --git a/argilla-v1/src/extralit_v1/preprocessing/text.py b/argilla-v1/src/extralit_v1/preprocessing/text.py new file mode 100644 index 000000000..988481bea --- /dev/null +++ b/argilla-v1/src/extralit_v1/preprocessing/text.py @@ -0,0 +1,184 @@ +import gc +import logging +from functools import partial +from pathlib import Path +from typing import List +from tqdm.asyncio import tqdm + +try: + import pypdf + import pypdfium2 + import torch + from nougat import NougatModel + from nougat.dataset.rasterize import rasterize_paper + from nougat.postprocessing import markdown_compatible, close_envs + from nougat.utils.checkpoint import get_checkpoint + from nougat.utils.dataset import LazyDataset, ImageDataset + from nougat.utils.device import default_batch_size, move_to_device + from torch.utils.data import ConcatDataset +except ImportError as ie: + raise ImportError("Please run `pip install extralit['ocr']` to install them.") from ie + + +class NougatOCR: + def __init__(self, model_tag="0.1.0-base", full_precision=False, markdown=True, skipping=True): + model_path = get_checkpoint(model_tag=model_tag) + self.model: NougatModel = NougatModel.from_pretrained(model_path) + self.markdown = markdown + self.skipping = skipping + + self.batch_size = default_batch_size() + self.model = move_to_device(self.model, bf16=not full_precision, cuda=self.batch_size > 0) + + if self.batch_size <= 0: + self.batch_size = 1 + + self.model.eval() + + def batch_predict(self, file_paths: List[Path]) -> List[List[str]]: + datasets = [] + + for pdf in file_paths: + if not pdf.exists(): + continue + + try: + dataset = LazyDataset( + pdf, + partial(self.model.encoder.prepare_input, random_padding=False), + ) + except pypdf.errors.PdfStreamError: + logging.info(f"Could not load file {str(pdf)}.") + continue + datasets.append(dataset) + if len(datasets) == 0: + return + + dataloader = torch.utils.data.DataLoader( + ConcatDataset(datasets), + batch_size=self.batch_size, + shuffle=False, + 
collate_fn=LazyDataset.ignore_none_collate, + ) + + documents = [] + predictions = [] + file_index = 0 + page_num = 0 + for i, (sample, is_last_page) in enumerate(tqdm(dataloader)): + model_output = self.model.inference(image_tensors=sample, early_stopping=self.skipping) + # check if model output is faulty + for j, output in enumerate(model_output["predictions"]): + if page_num == 0: + logging.info( + "Processing file %s with %i pages" % (datasets[file_index].name, datasets[file_index].size) + ) + page_num += 1 + if output.strip() == "[MISSING_PAGE_POST]": + # uncaught repetitions -- most likely empty page + predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n") + elif self.skipping and model_output["repeats"][j] is not None: + if model_output["repeats"][j] > 0: + # If we end up here, it means the output is most likely not complete and was truncated. + logging.warning(f"Skipping page {page_num} due to repetitions.") + predictions.append(f"\n\n[MISSING_PAGE_FAIL:{page_num}]\n\n") + else: + # If we end up here, it means the document page is too different from the training domain. + # This can happen e.g. for cover pages.
+ predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{i * self.batchsize + j + 1}]\n\n") + else: + if self.markdown: + output = markdown_compatible(output) + predictions.append(output) + + if is_last_page[j]: + documents.append(predictions) + + predictions = [] + page_num = 0 + file_index += 1 + + # clear the torch cache and memory + self.empty_cache() + + return documents + + def predict(self, file_path: str, verbose=True) -> List[str]: + with open(file_path, "rb") as file: + pdfbin = file.read() + pdf = pypdfium2.PdfDocument(pdfbin) + pages = list(range(len(pdf))) + + compute_pages = pages.copy() + images = rasterize_paper(pdf, pages=compute_pages) + + dataset = ImageDataset( + images, + partial(self.model.encoder.prepare_input, random_padding=False), + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=self.batch_size, + pin_memory=True, + shuffle=False, + ) + + # clear the torch cache and memory + self.empty_cache() + + predictions = [""] * len(pages) + for idx, sample in tqdm(enumerate(dataloader), total=len(dataloader), disable=not verbose): + if sample is None: + continue + + model_output = self.model.inference(image_tensors=sample, early_stopping=self.skipping) + + for page_idx, output in enumerate(model_output["predictions"]): + if model_output["repeats"][page_idx] is not None: + if model_output["repeats"][page_idx] > 0: + disclaimer = "\n\n%s\n\n" + else: + disclaimer = "\n\n%s\n\n" + + rest = close_envs(model_output["repetitions"][page_idx]).strip() + if len(rest) > 0: + disclaimer = disclaimer % rest + else: + disclaimer = "" + else: + disclaimer = "" + + predictions[pages.index(compute_pages[idx * self.batch_size + page_idx])] = ( + markdown_compatible(output) + disclaimer + ) + + self.empty_cache() + gc.collect(generation=2) + return predictions + + def empty_cache(self): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.backends.mps.is_available(): + torch.mps.empty_cache() + + def process_outputs(self, 
model_output): + predictions = [""] * len(model_output["predictions"]) + for page_idx, output in enumerate(model_output["predictions"]): + if model_output["repeats"][page_idx] is not None: + if model_output["repeats"][page_idx] > 0: + disclaimer = "\n\n+++ ==WARNING: Truncated because of repetitions==\n%s\n+++\n\n" + else: + disclaimer = "\n\n+++ ==ERROR: No output for this page==\n%s\n+++\n\n" + + rest = close_envs(model_output["repetitions"][page_idx]).strip() + if len(rest) > 0: + disclaimer = disclaimer % rest + else: + disclaimer = "" + else: + disclaimer = "" + + predictions[page_idx] = markdown_compatible(output) + disclaimer + return predictions diff --git a/argilla-v1/src/extralit_v1/schema/__init__.py b/argilla-v1/src/extralit_v1/schema/__init__.py new file mode 100644 index 000000000..1599f0f5b --- /dev/null +++ b/argilla-v1/src/extralit_v1/schema/__init__.py @@ -0,0 +1,6 @@ +from .checks import register_check_methods +from .dtypes.parse import stage_for_validate + +register_check_methods() + +__all__ = ["stage_for_validate"] diff --git a/argilla-v1/src/extralit_v1/schema/checks/__init__.py b/argilla-v1/src/extralit_v1/schema/checks/__init__.py new file mode 100644 index 000000000..9c2a8e864 --- /dev/null +++ b/argilla-v1/src/extralit_v1/schema/checks/__init__.py @@ -0,0 +1,64 @@ +import pandas as pd +import pandera as pa +from pandera.extensions import register_check_method + +from .consistency import check_less_than, check_greater_than, check_between +from .multilabels import is_valid_list_str, multiselect +from .suggestion import suggestion +from .time_elapsed import check_time_difference +from .dataframe import singleton + + +def register_check_methods() -> None: + """ + Register Pandera check methods for various check functions, ensuring no duplicate in registered check names. 
+ """ + if check_less_than.__name__ not in pa.Check: + register_check_method( + statistics=["columns_a", "columns_b", "or_equal"], supported_types=pd.DataFrame, check_type="vectorized" + )(check_less_than) + + if check_greater_than.__name__ not in pa.Check: + register_check_method( + statistics=["columns_a", "columns_b", "or_equal"], supported_types=pd.DataFrame, check_type="vectorized" + )(check_greater_than) + + if check_between.__name__ not in pa.Check: + register_check_method( + statistics=["columns_target", "columns_lower", "columns_upper", "or_equal"], + supported_types=pd.DataFrame, + check_type="vectorized", + )(check_between) + + if check_greater_than.__name__ not in pa.Check: + register_check_method(statistics=["col_a", "col_b", "or_equal"])(check_greater_than) + + if multiselect.__name__ not in pa.Check: + register_check_method(statistics=["delimiter", "isin"])(multiselect) + + if suggestion.__name__ not in pa.Check: + register_check_method(statistics=["values"])(suggestion) + + if check_time_difference.__name__ not in pa.Check: + register_check_method( + statistics=["field", "start_year", "start_month", "end_year", "end_month", "unit", "margin"], + check_type="vectorized", + )(check_time_difference) + + if singleton.__name__ not in pa.Check: + register_check_method(statistics=["enabled"], check_type="vectorized")(singleton) + + +register_check_methods() + +__all__ = [ + "is_valid_list_str", + "multiselect", + "suggestion", + "check_time_difference", + "check_less_than", + "check_greater_than", + "check_between", + "singleton", + "register_check_methods", +] diff --git a/argilla-v1/src/extralit_v1/schema/checks/consistency.py b/argilla-v1/src/extralit_v1/schema/checks/consistency.py new file mode 100644 index 000000000..b5eaf1216 --- /dev/null +++ b/argilla-v1/src/extralit_v1/schema/checks/consistency.py @@ -0,0 +1,119 @@ +from typing import Union, List + +import pandas as pd + +from extralit_v1.schema.checks.utils import make_same_length_arguments + + 
+def check_less_than( + df: pd.DataFrame, + *, + columns_a: Union[str, List[str]], + columns_b: Union[str, List[str]], + or_equal: Union[bool, List[str]] = False, +) -> pd.Series: + """ + Check if the values of columns in `col_a` are less than the values of columns in `col_b`. + + Args: + df: pd.DataFrame, required + The DataFrame to check. + columns_a: str or List[str], required + The column name or list of column names to compare. + columns_b: str or List[str], required + The column name or list of column names to compare. + or_equal: bool or List[bool], default=False + If True, the comparison will be inclusive. + """ + columns_a, columns_b, or_equal = make_same_length_arguments(columns_a, columns_b, or_equal) + assert ( + len(columns_a) == len(columns_b) == len(or_equal) + ), f"Input lists must have the same length, given {len(columns_a)}, {len(columns_b)}, {len(or_equal)}" + + checks = pd.Series([True] * len(df), index=df.index) + for a, b, oe in zip(columns_a, columns_b, or_equal): + if oe: + check = df[a] <= df[b] + else: + check = df[a] < df[b] + checks = checks & check.fillna(True) + + return checks + + +def check_greater_than( + df: pd.DataFrame, + *, + columns_a: Union[str, List[str]], + columns_b: Union[str, List[str]], + or_equal: Union[bool, List[str]] = False, +) -> pd.Series: + """ + Check if the values of columns in `col_a` are greater than the values of columns in `col_b`. + + Args: + df: pd.DataFrame, required + The DataFrame to check. + columns_a: str or List[str], required + The column name or list of column names to compare. + columns_b: str or List[str], required + The column name or list of column names to compare. + or_equal: bool or List[bool], default=False + If True, the comparison will be inclusive. 
+ + """ + columns_a, columns_b, or_equal = make_same_length_arguments(columns_a, columns_b, or_equal) + assert ( + len(columns_a) == len(columns_b) == len(or_equal) + ), f"Input lists must have the same length, given {len(columns_a)}, {len(columns_b)}, {len(or_equal)}" + + checks = pd.Series([True] * len(df), index=df.index) + for a, b, oe in zip(columns_a, columns_b, or_equal): + if oe: + check = df[a] >= df[b] + else: + check = df[a] > df[b] + checks = checks & check.fillna(True) + return checks + + +def check_between( + df: pd.DataFrame, + *, + columns_target: Union[str, List[str]], + columns_lower: Union[str, List[str]], + columns_upper: Union[str, List[str]], + or_equal: Union[bool, List[str]] = True, +) -> pd.Series: + """ + Check if the values of columns in `col_a` are between the values of columns in `col_b` and `col_c`. + + Args: + df: pd.DataFrame, required + The DataFrame to check. + columns_target: str or List[str], required + The column name or list of column names to compare. + columns_lower: str or List[str], required + The column name or list of column names to compare with `column` to be lower. + columns_upper: str or List[str], required + The column name or list of column names to compare. + or_equal: bool or List[bool], default=False + If True, the comparison will be inclusive. 
+ + """ + columns_target, columns_lower, columns_upper, or_equal = make_same_length_arguments( + columns_target, columns_lower, columns_upper, or_equal + ) + assert ( + len(columns_target) == len(columns_lower) == len(columns_upper) == len(or_equal) + ), f"Input lists must have the same length, given {len(columns_target)}, {len(columns_lower)}, {len(columns_upper)}, {len(or_equal)}" + + checks = pd.Series([True] * len(df), index=df.index) + for col, lower, upper, oe in zip(columns_target, columns_lower, columns_upper, or_equal): + if oe: + check = (df[col] >= df[lower]) & (df[col] <= df[upper]) + else: + check = (df[col] > df[lower]) & (df[col] < df[upper]) + checks = checks & check.fillna(True) + + return checks diff --git a/argilla-v1/src/extralit_v1/schema/checks/dataframe.py b/argilla-v1/src/extralit_v1/schema/checks/dataframe.py new file mode 100644 index 000000000..9c05b061f --- /dev/null +++ b/argilla-v1/src/extralit_v1/schema/checks/dataframe.py @@ -0,0 +1,9 @@ +from typing import Union, List +import pandas as pd + + +def singleton(df: pd.DataFrame, *, enabled: bool = True) -> bool: + if not enabled: + return True + + return df.index.is_unique diff --git a/argilla-v1/src/extralit_v1/schema/checks/join.py b/argilla-v1/src/extralit_v1/schema/checks/join.py new file mode 100644 index 000000000..55b695ada --- /dev/null +++ b/argilla-v1/src/extralit_v1/schema/checks/join.py @@ -0,0 +1,19 @@ +import pandas as pd + + +def unmatched_join_keys(df: pd.DataFrame, other: pd.DataFrame, key: str) -> pd.Series: + """ + Count the number of unmatched keys in the `df` DataFrame compared to the `other` DataFrame. 
+ """ + if key in df.columns: + merged_df = df.merge(other, on=key, how="left", indicator=True) + a_unmatched_keys = merged_df[merged_df["_merge"] == "left_only"][key] + + else: + merged_df = df.merge( + other, left_index=df.index.name == key, right_index=other.index.name == key, how="left", indicator=True + ) + + a_unmatched_keys = merged_df[merged_df["_merge"] == "left_only"].index + + return a_unmatched_keys.drop_duplicates() diff --git a/argilla-v1/src/extralit_v1/schema/checks/multilabels.py b/argilla-v1/src/extralit_v1/schema/checks/multilabels.py new file mode 100644 index 000000000..05283883f --- /dev/null +++ b/argilla-v1/src/extralit_v1/schema/checks/multilabels.py @@ -0,0 +1,21 @@ +from typing import List + +import pandas as pd + + +def is_valid_list_str(values: List[str]): + if not isinstance(values, list): + return True + + return all(isinstance(val, str) and val and val.strip() == val and not val.startswith("and ") for val in values) + + +def multiselect(series: pd.Series, *, delimiter=",", isin: List[str] = None): + """Check that the values in the series are valid lists of strings.""" + split_values = series.str.split(r"\s*" + delimiter + r"\s*", regex=True) + checks = split_values.apply(is_valid_list_str) + + if isinstance(isin, (set, list)) and isin: + checks = checks & split_values.apply(lambda x: all(is_valid_list_str(x) and set(x).issubset(isin))) + + return checks diff --git a/argilla-v1/src/extralit_v1/schema/checks/suggestion.py b/argilla-v1/src/extralit_v1/schema/checks/suggestion.py new file mode 100644 index 000000000..2129f7049 --- /dev/null +++ b/argilla-v1/src/extralit_v1/schema/checks/suggestion.py @@ -0,0 +1,15 @@ +import logging +from typing import Dict, Optional + +import pandas as pd + +_LOGGER = logging.getLogger(__name__) + + +def suggestion(series: pd.Series, values: Dict[str, Optional[Dict[str, str]]]): + mask = series.isin(values) | series.isna() + if not mask.all(): + print(f"INFO: Some `{series.name}` values were not in the 
def make_same_length_arguments(*args, **kwargs) -> Tuple[List, ...]:
    """Broadcast scalar arguments so they line up with the longest sequence argument.

    Positional arguments come first in the result, followed by keyword argument
    values in their given order. Lists and tuples are passed through unchanged;
    every other value is repeated into a list whose length equals the longest
    list/tuple supplied (or 1 when there is none).

    Returns:
        A tuple with one entry per input argument; an empty tuple when called
        without arguments.
    """
    all_args = list(args) + list(kwargs.values())

    # Use the same "is a sequence" rule here as in the loop below. The original
    # measured only lists, so a tuple argument was never taken into account when
    # sizing the broadcast scalars. `default=1` avoids max() failing on no input.
    max_len = max((len(arg) for arg in all_args if isinstance(arg, (list, tuple))), default=1)

    result = []
    for arg in all_args:
        if isinstance(arg, (list, tuple)):
            result.append(arg)
        else:
            result.append([arg] * max_len)

    return tuple(result)
def stage_for_validate(
    df: pd.DataFrame,
    schema: Union[pa.DataFrameModel, pa.DataFrameSchema],
    to_replace: Dict[str, Dict] = None,
    prefix_index_name: str = "publication_ref",
) -> pd.DataFrame:
    """
    Prepare a dataframe for schema validation by replacing NA-like placeholder values
    with a value that is appropriate for the data type of each column.

    :param df: Pandas dataframe
    :param schema: Pandera ``DataFrameModel`` or ``DataFrameSchema`` describing ``df``
    :param to_replace: Dictionary of values to replace. Keys are column names, values are dictionaries of
        values to replace. A column listed here uses its own mapping instead of the per-dtype default.
    :param prefix_index_name: Currently unused; the only statement that consumed it is commented out below.
    :return: Pandas dataframe with NA values replaced
    """
    # Accept either a DataFrameModel (which exposes to_schema()) or an already-built schema object.
    dataframe_schema = schema.to_schema() if hasattr(schema, "to_schema") else schema

    # df = prepend_reference_to_index_level(df, dataframe_schema, prefix_index_name)
    df = replace_na_values(df, dataframe_schema, to_replace)

    return df
def enumerate_group_id(df: pd.DataFrame, group_columns: List[str], start=1) -> pd.Series:
    """Number the rows of ``df`` by the distinct value combination of ``group_columns``.

    Columns that are entirely NA are ignored. Rows sharing the same combination
    receive the same number; numbers are handed out from ``start`` in order of
    first appearance. When no usable grouping column remains, a series filled
    with the string ``"0"`` is returned instead.
    """
    populated_columns = df.dropna(axis=1, how="all").columns
    usable_columns = [name for name in group_columns if name in populated_columns]

    if not usable_columns:
        # Fallback: integer indexes are offset by `start`, other indexes are reused as-is.
        fallback_index = df.index + start if df.index.dtype == int else df.index
        return pd.Series(index=fallback_index, dtype=str).fillna("0")

    # One "-"-joined key per row, built from the string form of each grouping value.
    row_keys = df[usable_columns].astype(str).apply(lambda row: "-".join(row), axis=1)

    # Assign consecutive numbers to keys in the order they are first seen.
    key_numbers = {}
    next_number = start
    for key in pd.unique(row_keys):
        key_numbers[key] = next_number
        next_number += 1

    return row_keys.map(key_numbers)
def map_items_to_references(
    items: Union[pd.Series, pd.DataFrame], reference_items: pd.DataFrame, threshold: float = 80
) -> List[str]:
    """
    Map `items` to `reference_items` using fuzzy matching.

    Each item and each reference row is flattened to a space-joined string over the
    columns the two share, and the best `fuzz.WRatio` match is kept when its score
    is strictly above `threshold`.

    Args:
        items: pd.Series or pd.DataFrame
            Either a series whose values are dicts (expanded into columns) or a
            DataFrame, to be mapped to `reference_items`
        reference_items: pd.DataFrame
            A DataFrame with reference items
        threshold: float
            Minimum (exclusive) match score for a reference to be assigned

    Returns:
        One entry per item, in item order: the index label of the best matching row
        of `reference_items`, or None when nothing matched.

    Raises:
        ValueError: if `items` is neither a pd.Series nor a pd.DataFrame.
    """
    if isinstance(items, pd.Series):
        assert (items.map(type) == dict).all(), "given items must be a pd.Series with dict values"
        items_df: pd.DataFrame = items.apply(pd.Series)
        dtypes = reference_items.dtypes.drop(reference_items.columns.difference(items_df.columns))
        items_df = items_df.astype(dtypes, errors="ignore")
    elif isinstance(items, pd.DataFrame):
        items_df = items
    else:
        raise ValueError(f"items must be a pd.Series or pd.DataFrame, given {type(items)}")

    joint_columns = items_df.columns.intersection(reference_items.columns)
    reference_mapping = [None for _ in range(len(items))]

    # Nothing to compare against: every item stays unmapped. (Previously an empty
    # reference table made extractOne return None, which crashed on unpacking.)
    if reference_items.empty or len(joint_columns) == 0:
        return reference_mapping

    # The reference strings do not depend on the item, so build them once.
    reference_strs = reference_items[joint_columns].apply(lambda row: " ".join(row.astype(str)), axis=1)

    # Write results by position: index labels of `items` are not guaranteed to be
    # 0..n-1, so using them as list offsets could misplace results or raise.
    for position, (_, item_row) in enumerate(items_df.iterrows()):
        item_str = " ".join([str(item_row[col]) for col in joint_columns if item_row[col] is not None])

        if not item_str.strip():
            continue

        match = process.extractOne(item_str, reference_strs, scorer=fuzz.WRatio)
        if match is None:
            continue

        # For a pandas Series of choices the third element is the index label of the match.
        _, score, reference_index = match
        if score > threshold:
            reference_mapping[position] = reference_index

    return reference_mapping
StreamingResponse +from langfuse.llama_index import LlamaIndexCallbackHandler +from langfuse.model import ChatPromptClient +from langfuse.utils.base_callback_handler import LangfuseBaseCallbackHandler +from llama_index.core.chat_engine.types import ChatMode +from llama_index.core.vector_stores import MetadataFilters, MetadataFilter, FilterOperator +from minio import Minio +from weaviate import WeaviateClient + +import argilla_v1 as rg +from extralit_v1.convert.json_table import json_to_df +from extralit_v1.extraction.extraction import extract_schema +from extralit_v1.extraction.models.paper import PaperExtraction +from extralit_v1.extraction.models.schema import SchemaStructure +from extralit_v1.extraction.prompts import DEFAULT_CHAT_PROMPT_TMPL, CHAT_SYSTEM_PROMPT +from extralit_v1.extraction.query import get_nodes_metadata, vectordb_contains_any +from extralit_v1.extraction.vector_index import create_vector_index, load_index +from extralit_v1.server.context.files import get_minio_client +from extralit_v1.server.context.llamaindex import get_langfuse_callback +from extralit_v1.server.context.vectordb import get_weaviate_client +from extralit_v1.server.models.extraction import ExtractionRequest, ExtractionResponse +from extralit_v1.server.models.segments import SegmentsResponse + +_LOGGER = logging.getLogger(__name__) +app = FastAPI() + +# app.add_middleware( +# CORSMiddleware, +# allow_origins=["http://argilla-server"], +# allow_credentials=True, +# allow_methods=["*"], +# allow_headers=["*"], +# ) + +weaviate_client: WeaviateClient = None +minio_client: Minio = None + + +@app.on_event("startup") +async def load_weaviate_client(): + global weaviate_client + if weaviate_client is None: + weaviate_client = get_weaviate_client() + + +@app.on_event("startup") +async def load_minio_client(): + global minio_client + if minio_client is None: + minio_client = get_minio_client() + + +@app.get("/health") +async def health_check(): + return {"status": "ok"} + + 
+@app.get("/schemas/{workspace}") +async def schemas( + workspace: str = "itn-recalibration", +): + ss = SchemaStructure.from_s3(workspace_name=workspace, minio_client=minio_client) + return ss.ordering + + +@app.get("/chat", status_code=status.HTTP_200_OK, response_class=StreamingResponse) +async def chat( + query: str = Query(...), + workspace: str = Query(...), + reference: str = Query(...), + similarity_top_k: int = Query(5, alias="k"), + chat_mode: ChatMode = Query(ChatMode.BEST), + llm_model: str = Query("gpt-3.5-turbo"), + username: Optional[Union[str, UUID]] = None, + prompt_template: str = "chat", + langfuse_callback: Optional[LlamaIndexCallbackHandler] = Depends(get_langfuse_callback), +): + index = load_index( + paper=pd.Series(name=reference), + llm_model=llm_model, + embed_model="text-embedding-3-small", + weaviate_client=weaviate_client, + index_name="LlamaIndexDocumentSections", + ) + + if not vectordb_contains_any(reference, weaviate_client=weaviate_client, index_name="LlamaIndexDocumentSections"): + raise HTTPException(status_code=404, detail=f"No context found for reference: {reference}") + + try: + if isinstance(langfuse_callback, LlamaIndexCallbackHandler): + langfuse_callback.set_trace_params( + name=f"chat-{reference}", + user_id=username, + session_id=reference, + tags=[workspace, reference, "chat"], + ) + except Exception as e: + _LOGGER.error(f"Failed to set trace params: {e}") + + # Get the system prompt + try: + chat_prompts: ChatPromptClient = langfuse_callback.langfuse.get_prompt(prompt_template, cache_ttl_seconds=3000) + system_prompt = chat_prompts.prompt[0]["content"] + except Exception as e: + _LOGGER.error(f"Failed to get system prompt: {e}") + system_prompt = None + + filters = MetadataFilters( + filters=[MetadataFilter(key="reference", value=reference, operator=FilterOperator.EQ)], + ) + + query_engine = index.as_chat_engine( + chat_mode=chat_mode, + vector_store_query_mode="hybrid", + alpha=0.25, + 
@app.post("/extraction", status_code=status.HTTP_201_CREATED, response_model=ExtractionResponse)
async def extraction(
    *,
    extraction_request: ExtractionRequest = Body(...),
    workspace: str = Query(...),
    model: str = "gpt-4o",
    similarity_top_k: int = 8,
    username: Optional[Union[str, UUID]] = None,
    prompt_template: str = "completion",
    langfuse_callback: Optional[LlamaIndexCallbackHandler] = Depends(get_langfuse_callback),
):
    """Run a schema extraction for one reference, using previously extracted tables as context.

    Args:
        extraction_request: Target schema name, reference, prior extractions and optional
            column/header/type filters and user prompt.
        workspace: Workspace whose schemas are loaded from S3; also used as a trace tag.
        model: LLM name passed to ``load_index``.
        similarity_top_k: Number of retrieved nodes; raised to ``len(headers)`` when more
            headers than that are requested.
        username: Forwarded to the Langfuse trace as ``user_id``.
        prompt_template: Name of the Langfuse prompt fetched as the system prompt.
        langfuse_callback: Tracing handler dependency; may be None.

    Raises:
        HTTPException: 404 when no context or no extraction was found, 500 when the index
            cannot be loaded or the extraction itself fails.
    """
    schema_structure = SchemaStructure.from_s3(workspace_name=workspace, minio_client=minio_client)
    schema = schema_structure[extraction_request.schema_name]

    # Use dedicated loop names: the original reused `schema` here, so the schema passed to
    # extract_schema() below silently became the last entry of `extractions` instead of the
    # one requested via `schema_name`.
    extraction_dfs = {}
    for prior_schema_name, extraction_dict in extraction_request.extractions.items():
        prior_schema = schema_structure[prior_schema_name]
        extraction_dfs[prior_schema.name] = json_to_df(extraction_dict, schema=prior_schema)

    extractions = PaperExtraction(
        reference=extraction_request.reference, extractions=extraction_dfs, schemas=schema_structure
    )

    # Get the system prompt (best effort: also covers langfuse_callback being None)
    try:
        system_prompt = langfuse_callback.langfuse.get_prompt(prompt_template, cache_ttl_seconds=3000, max_retries=0)
    except Exception as e:
        _LOGGER.error(f"Failed to get system prompt: {e}")
        system_prompt = None

    try:
        if isinstance(langfuse_callback, LlamaIndexCallbackHandler):
            langfuse_callback.set_trace_params(
                name=f"extract-{extraction_request.reference}",
                user_id=username,
                session_id=extraction_request.reference,
                tags=[workspace, extraction_request.reference, extraction_request.schema_name, "partial-extraction"],
            )
    except Exception as e:
        _LOGGER.error(f"Failed to set trace params: {e}")

    ### Create or load the index ###
    try:
        index = load_index(
            paper=pd.Series(name=extraction_request.reference),
            llm_model=model,
            embed_model="text-embedding-3-small",
            weaviate_client=weaviate_client,
            index_name="LlamaIndexDocumentSections",
        )
    except Exception as e:
        _LOGGER.error(f"Failed to create or load the index: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to create an extraction request: {e}")

    if extraction_request.headers and len(extraction_request.headers) > similarity_top_k:
        similarity_top_k = len(extraction_request.headers)

    try:
        ### Extract entities ###
        df, rag_response = extract_schema(
            schema=schema,
            extractions=extractions,
            index=index,
            include_fields=extraction_request.columns,
            headers=extraction_request.headers,
            types=extraction_request.types,
            similarity_top_k=similarity_top_k,
            system_prompt=system_prompt,
            user_prompt=extraction_request.prompt,
            vector_store_query_mode="hybrid",
        )

        if not isinstance(df, pd.DataFrame) or df.empty:
            if rag_response.source_nodes is None or len(rag_response.source_nodes) == 0:
                raise HTTPException(
                    status_code=404,
                    detail=f"There were no context selected due to stringent filters. Please modify your "
                    f"filters: {dict(headers=extraction_request.headers, types=extraction_request.types)}",
                )
            raise HTTPException(status_code=404, detail="No extraction found with the selected context and your query.")

        response = ExtractionResponse.parse_raw(df.to_json(orient="table"))
    except HTTPException:
        # Let the deliberate 404s above through; the generic handler below would
        # otherwise re-wrap them as 500 responses.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    if isinstance(langfuse_callback, LangfuseBaseCallbackHandler):
        langfuse_callback.flush()

    return response
index_name="LlamaIndexDocumentSections", + embed_model=embed_model, + ) + + return index.index_id + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/argilla/src/extralit/server/context/__init__.py b/argilla-v1/src/extralit_v1/server/context/__init__.py similarity index 100% rename from argilla/src/extralit/server/context/__init__.py rename to argilla-v1/src/extralit_v1/server/context/__init__.py diff --git a/argilla-v1/src/extralit_v1/server/context/datasets.py b/argilla-v1/src/extralit_v1/server/context/datasets.py new file mode 100644 index 000000000..bff89da30 --- /dev/null +++ b/argilla-v1/src/extralit_v1/server/context/datasets.py @@ -0,0 +1,23 @@ +import os +from typing import Optional + +import argilla_v1 as rg +from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset +from argilla.client.sdk.commons.errors import UnauthorizedApiError + + +def get_argilla_dataset( + dataset_name="Table-Preprocessing", workspace_name="itn-recalibration" +) -> RemoteFeedbackDataset: + try: + rg.init( + api_url=os.getenv("ARGILLA_BASE_URL"), + api_key=os.getenv("ARGILLA_API_KEY"), + workspace="argilla", + ) + except Exception as e: + print(e) + + dataset = rg.FeedbackDataset.from_argilla(name=dataset_name, workspace=workspace_name, with_documents=False) + + return dataset diff --git a/argilla/src/extralit/server/context/files.py b/argilla-v1/src/extralit_v1/server/context/files.py similarity index 65% rename from argilla/src/extralit/server/context/files.py rename to argilla-v1/src/extralit_v1/server/context/files.py index 5061f0e33..66381842e 100644 --- a/argilla/src/extralit/server/context/files.py +++ b/argilla-v1/src/extralit_v1/server/context/files.py @@ -8,9 +8,9 @@ def get_minio_client() -> Optional[Minio]: - s3_endpoint = os.getenv('S3_ENDPOINT') - s3_access_key = os.getenv('S3_ACCESS_KEY') - s3_secret_key = os.getenv('S3_SECRET_KEY') + s3_endpoint = os.getenv("S3_ENDPOINT") + s3_access_key = 
os.getenv("S3_ACCESS_KEY") + s3_secret_key = os.getenv("S3_SECRET_KEY") if s3_endpoint is None: return None @@ -21,11 +21,13 @@ def get_minio_client() -> Optional[Minio]: port = parsed_url.port if hostname is None: - _LOGGER.error(f"Invalid URL: no hostname in S3_ENDPOINT found, possible due to lacking http(s) protocol. Given '{s3_endpoint}'") + _LOGGER.error( + f"Invalid URL: no hostname in S3_ENDPOINT found, possible due to lacking http(s) protocol. Given '{s3_endpoint}'" + ) return None return Minio( - endpoint=f'{hostname}:{port}' if port else hostname, + endpoint=f"{hostname}:{port}" if port else hostname, access_key=s3_access_key, secret_key=s3_secret_key, secure=parsed_url.scheme == "https", diff --git a/argilla-v1/src/extralit_v1/server/context/llamaindex.py b/argilla-v1/src/extralit_v1/server/context/llamaindex.py new file mode 100644 index 000000000..7ab399458 --- /dev/null +++ b/argilla-v1/src/extralit_v1/server/context/llamaindex.py @@ -0,0 +1,28 @@ +import logging +import os +from typing import Optional, Union + +from langfuse.llama_index import LlamaIndexCallbackHandler +from langfuse.utils.base_callback_handler import LangfuseBaseCallbackHandler +from llama_index.core import Settings, set_global_handler + +_LOGGER = logging.getLogger(__name__) + + +def get_langfuse_callback( + langfuse_public_key: Optional[str] = None, langfuse_secret_key: Optional[str] = None +) -> Union[LangfuseBaseCallbackHandler, LlamaIndexCallbackHandler]: + try: + langfuse_callback_handler = LlamaIndexCallbackHandler( + host=os.getenv("LANGFUSE_HOST"), + public_key=langfuse_public_key if langfuse_public_key else os.getenv("LANGFUSE_PUBLIC_KEY"), + secret_key=langfuse_secret_key if langfuse_secret_key else os.getenv("LANGFUSE_SECRET_KEY"), + ) + if not Settings.callback_manager.handlers: + Settings.callback_manager.add_handler(langfuse_callback_handler) + set_global_handler("langfuse") + except Exception as e: + _LOGGER.error(f"Failed to create Langfuse callback handler: {e}") + 
def get_weaviate_client(http_port=None, http_secure=None, grpc_port=None, grpc_secure=None) -> Optional[WeaviateClient]:
    """Connect to Weaviate using the ``WCS_*`` environment variables.

    Host, port and TLS settings are parsed from ``WCS_HTTP_URL`` and ``WCS_GRPC_URL``;
    any argument passed explicitly (including ``False`` for the ``*_secure`` flags)
    takes precedence over the value derived from the URL. The first entry of the
    comma-separated ``WCS_API_KEY`` is used as the API key, and ``OPENAI_API_KEY``
    is forwarded as a header when set.

    Returns:
        A connected client, or None when ``WCS_HTTP_URL`` is not set or Weaviate
        fails to start. Any other error propagates to the caller.
    """
    if "WCS_HTTP_URL" not in os.environ:
        print("WCS_HTTP_URL not set")
        return None

    try:
        api_keys = os.getenv("WCS_API_KEY", "").split(",")

        # Parse HTTP URL
        http_url = urlparse(os.getenv("WCS_HTTP_URL"))
        http_port = http_port or http_url.port
        # `is None` rather than `or`, so an explicit False is not overridden by an https URL.
        if http_secure is None:
            http_secure = http_url.scheme == "https"

        # Parse GRPC URL (empty string when unset, so urlparse yields str fields)
        grpc_url = urlparse(os.getenv("WCS_GRPC_URL", ""))
        grpc_port = grpc_port or grpc_url.port
        if grpc_secure is None:
            grpc_secure = grpc_url.scheme == "https"

        weaviate_client = weaviate.connect_to_custom(
            http_host=http_url.hostname,
            http_port=http_port,
            http_secure=http_secure,
            grpc_host=grpc_url.hostname,
            grpc_port=grpc_port,
            grpc_secure=grpc_secure,
            auth_credentials=weaviate.auth.AuthApiKey(api_keys[0]),
            headers={"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]} if "OPENAI_API_KEY" in os.environ else None,
        )

        return weaviate_client

    except WeaviateStartUpError as wsue:
        print(f"Failed to start Weaviate: {wsue}")

    return None
a/argilla-v1/src/extralit_v1/server/models/extraction.py b/argilla-v1/src/extralit_v1/server/models/extraction.py new file mode 100644 index 000000000..8a5c31e84 --- /dev/null +++ b/argilla-v1/src/extralit_v1/server/models/extraction.py @@ -0,0 +1,39 @@ +from typing import Dict, Optional, List, Any, Union, Annotated, Any +from pydantic import BaseModel, Field, Extra + +SchemaName = Annotated[str, Field(description="The schema name of the extraction.", examples=["schema_name"])] +FieldName = Annotated[str, Field(description="The name of the field.", examples=["field_name"])] +Value = Annotated[Any, Field(description="The value of the field.", examples=["value"])] + +Data = List[Dict[FieldName, Value]] +Extractions = Dict[SchemaName, Data] + + +class ExtractionRequest(BaseModel): + reference: str + schema_name: str + extractions: Extractions = Field(default_factory=dict, description="All previously extracted data.") + columns: Optional[List[str]] = None + headers: Optional[List[str]] = None + types: Optional[List[str]] = None + prompt: Optional[str] = None + + +class FieldSchema(BaseModel): + name: str + type: Optional[str] = None + extDtype: Optional[str] = None + + +class Schema(BaseModel): + fields: List[FieldSchema] + primaryKey: Optional[List[str]] = None + pandas_version: Optional[str] + + +class ExtractionResponse(BaseModel): + schema: Schema + data: Data + + class Config: + extra = Extra.ignore diff --git a/argilla-v1/src/extralit_v1/server/models/segments.py b/argilla-v1/src/extralit_v1/server/models/segments.py new file mode 100644 index 000000000..1124ad3ad --- /dev/null +++ b/argilla-v1/src/extralit_v1/server/models/segments.py @@ -0,0 +1,18 @@ +from typing import Dict, Optional, List, Any, Union, Annotated, Any +from uuid import UUID + +from pydantic import BaseModel, Field, Extra + + +class SegmentResponse(BaseModel): + doc_id: str | UUID + header: str | None + page_number: int | None + type: str | None = Field(None, description="The type of the 
class StorageType(Enum):
    """Storage backends supported by ``FileHandler``."""

    FILE = "file"
    S3 = "s3"


class FileHandler:
    """Read, write and delete JSON, dill and text payloads on local disk or in an S3 bucket.

    All paths given to the public methods are relative to ``base_path``.
    """

    def __init__(self, base_path: str, storage_type: str = StorageType.FILE, bucket_name: Optional[str] = None):
        """
        Args:
            base_path: Directory (or object-key prefix) that every path is joined onto.
            storage_type: A ``StorageType`` member or its string value ("file" / "s3").
            bucket_name: Bucket to use; required when ``storage_type`` is S3.

        Raises:
            ValueError: if ``storage_type`` is not a known storage type.
        """
        self.base_path = base_path
        # Normalise to the enum. Plain strings (e.g. read from an environment variable)
        # never compare equal to StorageType members, which made every method fall
        # through both branches and silently return None.
        if isinstance(storage_type, str):
            storage_type = storage_type.strip().lower()
        self.storage_type = StorageType(storage_type)
        self.bucket_name = bucket_name

        if self.storage_type == StorageType.S3:
            assert bucket_name is not None
            self.client = get_minio_client()

    def _get_full_path(self, path: str) -> str:
        return os.path.join(self.base_path, path)

    def _read_object(self, full_path: str) -> bytes:
        """Download an object's bytes and hand the HTTP connection back to the pool."""
        response = self.client.get_object(self.bucket_name, full_path)
        try:
            return response.read()
        finally:
            response.close()
            response.release_conn()

    def _write_object(self, full_path: str, payload: bytes) -> None:
        """Upload ``payload``; put_object needs a readable stream and the length in bytes."""
        from io import BytesIO

        self.client.put_object(self.bucket_name, full_path, BytesIO(payload), len(payload))

    def exists(self, path: str) -> bool:
        """Return whether ``path`` exists in the configured storage."""
        full_path = self._get_full_path(path)
        if self.storage_type == StorageType.FILE:
            return os.path.exists(full_path)

        elif self.storage_type == StorageType.S3:
            try:
                self.client.stat_object(self.bucket_name, full_path)
                return True
            except S3Error:
                return False

    def read_json(self, path: str) -> dict:
        """Load and parse the JSON document stored at ``path``."""
        full_path = self._get_full_path(path)
        if self.storage_type == StorageType.FILE:
            with open(full_path, "r") as file:
                return json.load(file)

        elif self.storage_type == StorageType.S3:
            return json.loads(self._read_object(full_path).decode("utf-8"))

    def write_json(self, path: str, data: dict):
        """Serialise ``data`` as JSON to ``path``, creating parent directories for local files."""
        full_path = self._get_full_path(path)
        if self.storage_type == StorageType.FILE:
            os.makedirs(os.path.dirname(full_path), exist_ok=True)
            with open(full_path, "w") as file:
                json.dump(data, file)

        elif self.storage_type == StorageType.S3:
            self._write_object(full_path, json.dumps(data).encode("utf-8"))

    def read_dill(self, path: str):
        """Load a dill-serialised object from ``path``.

        NOTE(review): dill, like pickle, can execute arbitrary code while loading;
        only use this on storage that is trusted.
        """
        full_path = self._get_full_path(path)
        if self.storage_type == StorageType.FILE:
            with open(full_path, "rb") as file:
                return dill.load(file)

        elif self.storage_type == StorageType.S3:
            return dill.loads(self._read_object(full_path))

    def write_dill(self, path: str, data):
        """Serialise ``data`` with dill to ``path``."""
        full_path = self._get_full_path(path)
        if self.storage_type == StorageType.FILE:
            os.makedirs(os.path.dirname(full_path), exist_ok=True)
            with open(full_path, "wb") as file:
                dill.dump(data, file)

        elif self.storage_type == StorageType.S3:
            self._write_object(full_path, dill.dumps(data))

    def read_text(self, path: str) -> str:
        """Return the text stored at ``path``."""
        full_path = self._get_full_path(path)
        if self.storage_type == StorageType.FILE:
            with open(full_path, "r") as file:
                return file.read()

        elif self.storage_type == StorageType.S3:
            return self._read_object(full_path).decode("utf-8")

    def write_text(self, path: str, data: str):
        """Write ``data`` as text to ``path``."""
        full_path = self._get_full_path(path)
        if self.storage_type == StorageType.FILE:
            os.makedirs(os.path.dirname(full_path), exist_ok=True)
            with open(full_path, "w") as file:
                file.write(data)

        elif self.storage_type == StorageType.S3:
            # Length is measured on the encoded bytes: len(data) counts characters
            # and is too small for any non-ASCII text.
            self._write_object(full_path, data.encode("utf-8"))

    def delete(self, path: str):
        """Remove ``path``.

        Raises:
            FileNotFoundError: if the local file does not exist or the S3 removal fails.
        """
        full_path = self._get_full_path(path)
        if self.storage_type == StorageType.FILE:
            if os.path.exists(full_path):
                os.remove(full_path)
            else:
                raise FileNotFoundError(f"The file {full_path} does not exist.")

        elif self.storage_type == StorageType.S3:
            try:
                self.client.remove_object(self.bucket_name, full_path)
            except S3Error as e:
                raise FileNotFoundError(f"The object {full_path} does not exist in bucket {self.bucket_name}.") from e
SchemaStructure +from extralit_v1.storage.files import FileHandler, StorageType register_check_methods() from ..database import SyncTestSession, TestSession, set_task + @pytest.fixture(scope="function") def client(request, mocker: "MockerFixture") -> Generator[TestClient, None, None]: - from extralit.server.app import app + from extralit_v1.server.app import app async def override_get_async_db(): session = TestSession() @@ -36,20 +37,21 @@ async def override_get_async_db(): def mock_dependencies(mocker: "MockerFixture"): mocker.patch("extralit.server.context.vectordb.get_weaviate_client", return_value=MagicMock()) mocker.patch("extralit.server.context.files.get_minio_client", return_value=MagicMock()) - mocker.patch("extralit.server.context.llamaindex.get_langfuse_callback", return_value=MagicMock()) + mocker.patch("extralit.server.context.llamaindex.get_langfuse_callback", return_value=MagicMock()) class MockSchema(pa.DataFrameModel): """ General information about the publication, extracted once per paper. 
""" + reference: Index[str] = pa.Field(check_name=True) title: Series[str] = pa.Field() authors: Series[str] = pa.Field() journal: Series[str] = pa.Field() publication_year: Series[int] = pa.Field(ge=1900, le=2100) doi: Series[str] = pa.Field(nullable=True) - + class Config: singleton = True @@ -77,7 +79,9 @@ def local_file_handler() -> FileHandler: @pytest.fixture def s3_file_handler() -> FileHandler: # Create a mock FileHandler with S3 storage type - file_handler = FileHandler(base_path='data/preprocessing/', storage_type=StorageType.S3, bucket_name='test-workspace') + file_handler = FileHandler( + base_path="data/preprocessing/", storage_type=StorageType.S3, bucket_name="test-workspace" + ) file_handler.client = MagicMock() - - return file_handler \ No newline at end of file + + return file_handler diff --git a/argilla-v1/tests/extralit/metrics/conftest.py b/argilla-v1/tests/extralit/metrics/conftest.py index 41eaf6aff..3ce5a47ee 100644 --- a/argilla-v1/tests/extralit/metrics/conftest.py +++ b/argilla-v1/tests/extralit/metrics/conftest.py @@ -3,16 +3,17 @@ import pandas as pd import pytest -from extralit.extraction.models.paper import PaperExtraction -from extralit.extraction.models.schema import SchemaStructure -from extralit.schema.checks import register_check_methods +from extralit_v1.extraction.models.paper import PaperExtraction +from extralit_v1.extraction.models.schema import SchemaStructure +from extralit_v1.schema.checks import register_check_methods register_check_methods() + @pytest.fixture def mock_schema_structure() -> SchemaStructure: current_dir = os.path.dirname(os.path.abspath(__file__)) - relative_path = os.path.join(current_dir, '..', 'assets', 'schemas') + relative_path = os.path.join(current_dir, "..", "assets", "schemas") directory_path = os.path.normpath(relative_path) return SchemaStructure.from_dir(directory_path) @@ -20,16 +21,19 @@ def mock_schema_structure() -> SchemaStructure: @pytest.fixture def 
mock_observation_df(mock_schema_structure, schema_name="Observation") -> pd.DataFrame: - df = pd.DataFrame({ - 'Study_type': ['Hut trial', 'Lab based bioassay'], - 'Country': ['Country1', 'Country2'], - 'Site': ['Site1', 'Site2'], - 'Start_month': [1, 2], - 'Start_year': [2000, 2001], - 'End_month': [3, 1], - 'End_year': [2001, 2002], - 'Time_elapsed': [14.0, 11.0] - }, index=pd.Index(['ref1', 'ref2'], name='reference')) + df = pd.DataFrame( + { + "Study_type": ["Hut trial", "Lab based bioassay"], + "Country": ["Country1", "Country2"], + "Site": ["Site1", "Site2"], + "Start_month": [1, 2], + "Start_year": [2000, 2001], + "End_month": [3, 1], + "End_year": [2001, 2002], + "Time_elapsed": [14.0, 11.0], + }, + index=pd.Index(["ref1", "ref2"], name="reference"), + ) schema = mock_schema_structure[schema_name] # validated_df = schema.validate(df) return df @@ -48,16 +52,19 @@ def mock_paper_extraction_true(mock_schema_structure, mock_observation_df) -> Pa @pytest.fixture def mock_paper_extraction_pred(mock_schema_structure) -> PaperExtraction: - df = pd.DataFrame({ - 'Study_type': ['Hut trial', 'Lab based bioassay'], - 'Country': ['Country1', 'Country2'], - 'Site': ['Site1', 'Site2'], - 'Start_month': [1, 2], - 'Start_year': [2000, 2001], - 'End_month': [3, 1], - 'End_year': [2001, 2002], - # 'Time_elapsed': [16.0, 13.0] - }, index=pd.Index(['ref3', 'ref4'], name='reference')) + df = pd.DataFrame( + { + "Study_type": ["Hut trial", "Lab based bioassay"], + "Country": ["Country1", "Country2"], + "Site": ["Site1", "Site2"], + "Start_month": [1, 2], + "Start_year": [2000, 2001], + "End_month": [3, 1], + "End_year": [2001, 2002], + # 'Time_elapsed': [16.0, 13.0] + }, + index=pd.Index(["ref3", "ref4"], name="reference"), + ) mock_data = { "Observation": df, diff --git a/argilla-v1/tests/extralit/metrics/test_extraction.py b/argilla-v1/tests/extralit/metrics/test_extraction.py index 0c20afb64..022d8931e 100644 --- a/argilla-v1/tests/extralit/metrics/test_extraction.py +++ 
b/argilla-v1/tests/extralit/metrics/test_extraction.py @@ -1,10 +1,13 @@ from typing import TYPE_CHECKING -from extralit.metrics.extraction import grits_paper +from extralit_v1.metrics.extraction import grits_paper if TYPE_CHECKING: - from extralit.extraction.models.paper import PaperExtraction + from extralit_v1.extraction.models.paper import PaperExtraction -def test_grits_paper_with_pred_missing_column(mock_paper_extraction_true: 'PaperExtraction', mock_paper_extraction_pred: 'PaperExtraction'): + +def test_grits_paper_with_pred_missing_column( + mock_paper_extraction_true: "PaperExtraction", mock_paper_extraction_pred: "PaperExtraction" +): result = grits_paper(mock_paper_extraction_true, mock_paper_extraction_pred) assert result["Observation"].loc[("grits", "con", "precision")] == 1.0 assert result["Observation"].loc[("grits", "con", "recall")] < 1.0 @@ -12,7 +15,9 @@ def test_grits_paper_with_pred_missing_column(mock_paper_extraction_true: 'Paper assert result["ITNCondition"].loc[("grits", "con", "f1")] == 1.0 -def test_grits_paper_with_pred_missing_column_reversed(mock_paper_extraction_true: 'PaperExtraction', mock_paper_extraction_pred: 'PaperExtraction'): +def test_grits_paper_with_pred_missing_column_reversed( + mock_paper_extraction_true: "PaperExtraction", mock_paper_extraction_pred: "PaperExtraction" +): result = grits_paper(mock_paper_extraction_pred, mock_paper_extraction_true) assert result["Observation"].loc[("grits", "con", "precision")] < 1.0 assert result["Observation"].loc[("grits", "con", "recall")] == 1.0 diff --git a/argilla-v1/tests/extralit/metrics/test_grits.py b/argilla-v1/tests/extralit/metrics/test_grits.py index 06013de18..08c4a8bf4 100644 --- a/argilla-v1/tests/extralit/metrics/test_grits.py +++ b/argilla-v1/tests/extralit/metrics/test_grits.py @@ -1,37 +1,37 @@ import pandas as pd -from extralit.metrics.extraction import grits_from_pandas +from extralit_v1.metrics.extraction import grits_from_pandas def 
test_identical_dataframes_grits_from_pandas(): - df1 = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) + df1 = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) df2 = df1.copy() - result = grits_from_pandas(df1, df2, format='series') + result = grits_from_pandas(df1, df2, format="series") assert result.loc[("grits", "top", "f1")] == 1.0 assert result.loc[("grits", "con", "f1")] == 1.0 def test_different_dataframes_grits_from_pandas(): - df1 = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) - df2 = pd.DataFrame({'A': [1, 3], 'B': [2, 4]}) - result = grits_from_pandas(df1, df2, format='series') + df1 = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + df2 = pd.DataFrame({"A": [1, 3], "B": [2, 4]}) + result = grits_from_pandas(df1, df2, format="series") assert result.loc[("grits", "top", "f1")] == 1.0 assert result.loc[("grits", "con", "f1")] != 1.0 def test_grits_from_pandas_missing_column_in_true_df(): - true_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) - pred_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4], 'C': [5, 6]}) - result = grits_from_pandas(true_df, pred_df, metrics=['con'], format='series', verbose=2) + true_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + pred_df = pd.DataFrame({"A": [1, 2], "B": [3, 4], "C": [5, 6]}) + result = grits_from_pandas(true_df, pred_df, metrics=["con"], format="series", verbose=2) assert result.loc[("grits", "con", "f1")] != 1.0 assert result.loc[("grits", "con", "recall")] == 1.0 def test_grits_from_pandas_missing_column_in_pred_df(): - true_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4], 'C': [5, 6]}) - pred_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) - result = grits_from_pandas(true_df, pred_df, metrics=['con', 'alignment'], verbose=2) - alignment = result['alignment'] + true_df = pd.DataFrame({"A": [1, 2], "B": [3, 4], "C": [5, 6]}) + pred_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + result = grits_from_pandas(true_df, pred_df, metrics=["con", "alignment"], verbose=2) + alignment = result["alignment"] assert alignment.shape == pred_df.shape assert 
result["grits_con_precision"] == 1.0 assert result["grits_con_f1"] == 0.8 @@ -39,96 +39,95 @@ def test_grits_from_pandas_missing_column_in_pred_df(): def test_grits_from_pandas_empty_true_df(): true_df = pd.DataFrame() - pred_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) - result = grits_from_pandas(true_df, pred_df, metrics=['con'], format='series') + pred_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + result = grits_from_pandas(true_df, pred_df, metrics=["con"], format="series") assert result.loc[("grits", "con", "precision")] == 0.0 assert result.loc[("grits", "con", "recall")] == 1.0 def test_grits_from_pandas_empty_pred_df(): - true_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) + true_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) pred_df = pd.DataFrame() - result = grits_from_pandas(true_df, pred_df, metrics=['con'], format='series') + result = grits_from_pandas(true_df, pred_df, metrics=["con"], format="series") assert result.loc[("grits", "con", "precision")] == 1.0 assert result.loc[("grits", "con", "recall")] == 0.0 def test_extra_column_in_pred_df_grits_from_pandas(): - true_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) - pred_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4], 'C': [5, 6]}) - result = grits_from_pandas(true_df, pred_df, format='series') + true_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + pred_df = pd.DataFrame({"A": [1, 2], "B": [3, 4], "C": [5, 6]}) + result = grits_from_pandas(true_df, pred_df, format="series") assert result.loc[("grits", "top", "f1")] != 1.0 assert result.loc[("grits", "con", "f1")] != 1.0 def test_grits_from_pandas_missing_row_in_true_df(): - true_df = pd.DataFrame({'A': [1], 'B': [3]}) - pred_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) - result = grits_from_pandas(true_df, pred_df, metrics=['con', 'alignment'], verbose=2) - alignment = result['alignment'] + true_df = pd.DataFrame({"A": [1], "B": [3]}) + pred_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + result = grits_from_pandas(true_df, pred_df, metrics=["con", 
"alignment"], verbose=2) + alignment = result["alignment"] assert alignment.shape == true_df.shape assert result["grits_con_recall"] == 1.0 assert result["grits_con_f1"] == 0.8 def test_grits_from_pandas_missing_row_in_pred_df(): - true_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) - pred_df = pd.DataFrame({'A': [1], 'B': [3]}) - result = grits_from_pandas(true_df, pred_df, metrics=['con', 'alignment'], verbose=2) - alignment = result['alignment'] + true_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + pred_df = pd.DataFrame({"A": [1], "B": [3]}) + result = grits_from_pandas(true_df, pred_df, metrics=["con", "alignment"], verbose=2) + alignment = result["alignment"] assert alignment.shape == pred_df.shape assert result["grits_con_precision"] == 1.0 assert result["grits_con_f1"] == 0.8 + def test_identical_dataframes_different_dtypes_grits_from_pandas(): - df1 = pd.DataFrame({'A': [1, 2], 'B': [3.0, 4.0]}) - df2 = pd.DataFrame({'A': ['1', '2'], 'B': ['3', '4']}) - result = grits_from_pandas(df1, df2, format='series') + df1 = pd.DataFrame({"A": [1, 2], "B": [3.0, 4.0]}) + df2 = pd.DataFrame({"A": ["1", "2"], "B": ["3", "4"]}) + result = grits_from_pandas(df1, df2, format="series") assert result.loc[("grits", "top", "f1")] == 1.0 assert result.loc[("grits", "con", "f1")] == 1.0 def test_grits_from_pandas_with_empty_dataframe(): - df1 = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) + df1 = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) df2 = pd.DataFrame() - result = grits_from_pandas(df1, df2, format='series') + result = grits_from_pandas(df1, df2, format="series") assert result.loc[("grits", "top", "f1")] == 0.0 assert result.loc[("grits", "con", "f1")] == 0.0 - result = grits_from_pandas(df2, df1, format='series') + result = grits_from_pandas(df2, df1, format="series") assert result.loc[("grits", "top", "f1")] == 0.0 assert result.loc[("grits", "con", "f1")] == 0.0 def test_grits_from_pandas_with_excluded_columns(): - df1 = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) - df2 = 
pd.DataFrame({'A': [1, 2], 'B': [3, 5]}) - result = grits_from_pandas(df1, df2, format='series', exclude_columns=['B']) + df1 = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + df2 = pd.DataFrame({"A": [1, 2], "B": [3, 5]}) + result = grits_from_pandas(df1, df2, format="series", exclude_columns=["B"]) assert result.loc[("grits", "top", "f1")] == 1.0 assert result.loc[("grits", "con", "f1")] == 1.0 def test_reduce_column_grits_from_pandas(): - true_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) - pred_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) - result = grits_from_pandas(true_df, pred_df, format='series', reduce='column') - assert result.loc[("grits", "top", "f1"), 'A'] == 1.0 - assert result.loc[("grits", "con", "f1"), 'B'] == 1.0 + true_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + pred_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + result = grits_from_pandas(true_df, pred_df, format="series", reduce="column") + assert result.loc[("grits", "top", "f1"), "A"] == 1.0 + assert result.loc[("grits", "con", "f1"), "B"] == 1.0 def test_different_permutation_grits_from_pandas(): - true_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}, index=['x', 'y']) - pred_df = pd.DataFrame({'B': [4, 3], 'A': [2, 1]}, index=['y', 'x']) - result = grits_from_pandas(true_df, pred_df, format='series') + true_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}, index=["x", "y"]) + pred_df = pd.DataFrame({"B": [4, 3], "A": [2, 1]}, index=["y", "x"]) + result = grits_from_pandas(true_df, pred_df, format="series") assert result.loc[("grits", "top", "f1")] == 1.0 assert result.loc[("grits", "con", "f1")] == 1.0 def test_index_columns_grits_from_pandas(): - true_df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}, index=['x', 'y']) - pred_df = pd.DataFrame({'A': [2, 1], 'B': [4, 3]}, index=['y', 'x']) - result = grits_from_pandas(true_df, pred_df, format='series', index_columns=['A', 'B']) + true_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}, index=["x", "y"]) + pred_df = pd.DataFrame({"A": [2, 1], "B": [4, 3]}, 
index=["y", "x"]) + result = grits_from_pandas(true_df, pred_df, format="series", index_columns=["A", "B"]) assert result.loc[("grits", "top", "f1")] == 1.0 assert result.loc[("grits", "con", "f1")] == 1.0 - - diff --git a/argilla-v1/tests/extralit/metrics/test_utils.py b/argilla-v1/tests/extralit/metrics/test_utils.py index 1e99ebb39..ec3b4d832 100644 --- a/argilla-v1/tests/extralit/metrics/test_utils.py +++ b/argilla-v1/tests/extralit/metrics/test_utils.py @@ -1,66 +1,66 @@ import pandas as pd -from extralit.metrics import utils +from extralit_v1.metrics import utils from tests.extralit.utils import assert_frame_equal def test_harmonize_columns_with_identical_dataframes(): - df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) - df2 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + df2 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) result1, result2 = utils.harmonize_columns(df1, df2) assert_frame_equal(result1, df1) assert_frame_equal(result2, df2) def test_harmonize_columns_with_different_column_order(): - df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) - df2 = pd.DataFrame({'b': [3, 4], 'a': [1, 2]}) + df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + df2 = pd.DataFrame({"b": [3, 4], "a": [1, 2]}) result1, result2 = utils.harmonize_columns(df1, df2) assert_frame_equal(result1, df1) assert_frame_equal(result2, df1) def test_reorder_rows_with_identical_dataframes(): - df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) - df2 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + df2 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) result1, result2 = utils.reorder_rows(df1, df2) assert_frame_equal(result1, df1) assert_frame_equal(result2, df2) def test_reorder_rows_with_different_row_order(): - df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) - df2 = pd.DataFrame({'a': [2, 1], 'b': [4, 3]}) + df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + df2 = pd.DataFrame({"a": [2, 1], "b": [4, 3]}) result1, result2 = 
utils.reorder_rows(df1, df2, verbose=True) assert_frame_equal(result1, df1) assert_frame_equal(result2, df1) def test_convert_metrics_to_df_with_series_format(): - metrics = {'a_b': 1.0, 'c_d': 2.0} - result = utils.convert_metrics_to_df(metrics, 'series') + metrics = {"a_b": 1.0, "c_d": 2.0} + result = utils.convert_metrics_to_df(metrics, "series") expected = pd.Series(metrics) - expected.index = expected.index.str.split('_', n=2, expand=True) + expected.index = expected.index.str.split("_", n=2, expand=True) pd.testing.assert_series_equal(result, expected) def test_convert_metrics_to_df_with_dataframe_format(): - metrics = {'a_b': 1.0, 'c_d': 2.0} - result = utils.convert_metrics_to_df(metrics, 'dataframe') + metrics = {"a_b": 1.0, "c_d": 2.0} + result = utils.convert_metrics_to_df(metrics, "dataframe") expected = pd.Series(metrics).to_frame() - expected.index = expected.index.str.split('_', n=2, expand=True) + expected.index = expected.index.str.split("_", n=2, expand=True) assert_frame_equal(result, expected) def test_most_similar_columns_with_identical_dataframes(): - df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) - df2 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + df2 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) result = utils.most_similar_columns(df1, df2) - assert result == ['a', 'b'] + assert result == ["a", "b"] def test_most_similar_columns_with_different_dataframes(): - df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) - df2 = pd.DataFrame({'a': [1, 2], 'c': [3, 4]}) + df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + df2 = pd.DataFrame({"a": [1, 2], "c": [3, 4]}) result = utils.most_similar_columns(df1, df2) - assert result == ['a'] + assert result == ["a"] diff --git a/argilla-v1/tests/extralit/preprocessing/test_document.py b/argilla-v1/tests/extralit/preprocessing/test_document.py index 390cfc0d7..9bf0db4cf 100644 --- a/argilla-v1/tests/extralit/preprocessing/test_document.py +++ 
b/argilla-v1/tests/extralit/preprocessing/test_document.py @@ -1,11 +1,13 @@ import pandas as pd from unittest.mock import MagicMock, patch -from extralit.preprocessing.document import create_or_load_deepdoctection_segments, create_or_load_nougat_segments -from extralit.preprocessing.segment import Segments -from extralit.storage.files import FileHandler +from extralit_v1.preprocessing.document import create_or_load_deepdoctection_segments, create_or_load_nougat_segments +from extralit_v1.preprocessing.segment import Segments +from extralit_v1.storage.files import FileHandler -def test_create_or_load_deepdoctection_segments_load_only(mock_paper: 'pd.Series', local_file_handler: 'MagicMock', mock_deepdoctection: 'MagicMock'): +def test_create_or_load_deepdoctection_segments_load_only( + mock_paper: "pd.Series", local_file_handler: "MagicMock", mock_deepdoctection: "MagicMock" +): local_file_handler.exists.return_value = True local_file_handler.read_text.return_value = Segments().json() @@ -18,9 +20,12 @@ def test_create_or_load_deepdoctection_segments_load_only(mock_paper: 'pd.Series assert texts is not None assert tables is not None assert figures is not None - local_file_handler.read_text.assert_any_call('data/preprocessing/deepdoctection/test-paper/tables.json') + local_file_handler.read_text.assert_any_call("data/preprocessing/deepdoctection/test-paper/tables.json") -def test_create_or_load_deepdoctection_segments_redo(mock_paper: 'pd.Series', local_file_handler: 'MagicMock', mock_deepdoctection: 'MagicMock'): + +def test_create_or_load_deepdoctection_segments_redo( + mock_paper: "pd.Series", local_file_handler: "MagicMock", mock_deepdoctection: "MagicMock" +): local_file_handler.exists.return_value = False mock_deepdoctection.get_dd_analyzer.return_value.analyze.return_value = MagicMock() mock_deepdoctection.get_dd_analyzer.return_value.analyze.return_value.reset_state.return_value = None @@ -36,10 +41,13 @@ def 
test_create_or_load_deepdoctection_segments_redo(mock_paper: 'pd.Series', lo assert texts is None assert tables is not None assert figures is None - mock_makedirs.assert_called_with('data/preprocessing/deepdoctection/test-paper/tables', exist_ok=True) - mock_deepdoctection.get_dd_analyzer.return_value.analyze.assert_called_with(path='/tmp/test_pdf.pdf') + mock_makedirs.assert_called_with("data/preprocessing/deepdoctection/test-paper/tables", exist_ok=True) + mock_deepdoctection.get_dd_analyzer.return_value.analyze.assert_called_with(path="/tmp/test_pdf.pdf") + -def test_create_or_load_deepdoctection_segments_save(mock_paper: 'pd.Series', local_file_handler: 'MagicMock', mock_deepdoctection: 'MagicMock'): +def test_create_or_load_deepdoctection_segments_save( + mock_paper: "pd.Series", local_file_handler: "MagicMock", mock_deepdoctection: "MagicMock" +): local_file_handler.exists.return_value = False mock_deepdoctection.get_dd_analyzer.return_value.analyze.return_value = MagicMock() mock_deepdoctection.get_dd_analyzer.return_value.analyze.return_value.reset_state.return_value = None @@ -56,21 +64,28 @@ def test_create_or_load_deepdoctection_segments_save(mock_paper: 'pd.Series', lo assert texts is None assert tables is not None assert figures is None - mock_makedirs.assert_called_with('data/preprocessing/deepdoctection/test-paper/tables', exist_ok=True) - mock_deepdoctection.get_dd_analyzer.return_value.analyze.assert_called_with(path='/tmp/test_pdf.pdf') - local_file_handler.write_text.assert_any_call('data/preprocessing/deepdoctection/test-paper/tables.json', tables.json()) + mock_makedirs.assert_called_with("data/preprocessing/deepdoctection/test-paper/tables", exist_ok=True) + mock_deepdoctection.get_dd_analyzer.return_value.analyze.assert_called_with(path="/tmp/test_pdf.pdf") + local_file_handler.write_text.assert_any_call( + "data/preprocessing/deepdoctection/test-paper/tables.json", tables.json() + ) + -def 
test_create_or_load_deepdoctection_segments_load_from_s3(mock_paper: 'pd.Series', s3_file_handler: 'FileHandler', mock_deepdoctection: 'MagicMock'): +def test_create_or_load_deepdoctection_segments_load_from_s3( + mock_paper: "pd.Series", s3_file_handler: "FileHandler", mock_deepdoctection: "MagicMock" +): # Mock the minio client methods - s3_file_handler.client.stat_object.side_effect = lambda bucket, path: None if path.endswith('page_1.json') else Exception() - s3_file_handler.client.get_object.side_effect = lambda bucket, path: MagicMock(read=lambda: Segments().json().encode('utf-8')) + s3_file_handler.client.stat_object.side_effect = ( + lambda bucket, path: None if path.endswith("page_1.json") else Exception() + ) + s3_file_handler.client.get_object.side_effect = lambda bucket, path: MagicMock( + read=lambda: Segments().json().encode("utf-8") + ) # Mock the file handler methods - s3_file_handler.read_text = MagicMock(side_effect=[ - Segments().json(), Segments().json(), Segments().json() - ]) + s3_file_handler.read_text = MagicMock(side_effect=[Segments().json(), Segments().json(), Segments().json()]) - with patch("glob.glob", return_value=['data/preprocessing/deepdoctection/test-paper/page_1.json']): + with patch("glob.glob", return_value=["data/preprocessing/deepdoctection/test-paper/page_1.json"]): texts, tables, figures = create_or_load_deepdoctection_segments( paper=mock_paper, load_only=True, @@ -80,10 +95,10 @@ def test_create_or_load_deepdoctection_segments_load_from_s3(mock_paper: 'pd.Ser assert texts is not None assert tables is not None assert figures is not None - s3_file_handler.read_text.assert_any_call('data/preprocessing/deepdoctection/test-paper/tables.json') + s3_file_handler.read_text.assert_any_call("data/preprocessing/deepdoctection/test-paper/tables.json") -def test_create_or_load_nougat_segments_load_only(mock_paper: 'pd.Series', local_file_handler: 'MagicMock'): +def test_create_or_load_nougat_segments_load_only(mock_paper: "pd.Series", 
local_file_handler: "MagicMock"): local_file_handler.exists.return_value = True local_file_handler.read_text.return_value = Segments().json() @@ -96,62 +111,81 @@ def test_create_or_load_nougat_segments_load_only(mock_paper: 'pd.Series', local assert texts is not None assert tables is not None assert figures is not None - local_file_handler.read_text.assert_any_call('data/preprocessing/nougat/test-paper/tables.json') + local_file_handler.read_text.assert_any_call("data/preprocessing/nougat/test-paper/tables.json") -def test_create_or_load_nougat_segments_redo(mock_paper: 'pd.Series', local_file_handler: 'MagicMock', mock_nougat: MagicMock): +def test_create_or_load_nougat_segments_redo( + mock_paper: "pd.Series", local_file_handler: "MagicMock", mock_nougat: MagicMock +): local_file_handler.exists.return_value = False mock_nougat.NougatOCR.return_value.predict.return_value = MagicMock() - with patch("os.makedirs") as mock_makedirs, patch("os.environ", {}), patch("extralit.preprocessing.document.isinstance", return_value=True): + with ( + patch("os.makedirs") as mock_makedirs, + patch("os.environ", {}), + patch("extralit.preprocessing.document.isinstance", return_value=True), + ): texts, tables, figures = create_or_load_nougat_segments( paper=mock_paper, load_only=False, redo=True, file_handler=local_file_handler, - nougat_model=mock_nougat.NougatOCR() + nougat_model=mock_nougat.NougatOCR(), ) assert texts is not None assert tables is not None assert figures is None # mock_makedirs.assert_called_with('data/preprocessing/nougat/test-paper', exist_ok=True) - mock_nougat.NougatOCR.return_value.predict.assert_called_with('/tmp/test_pdf.pdf') + mock_nougat.NougatOCR.return_value.predict.assert_called_with("/tmp/test_pdf.pdf") -def test_create_or_load_nougat_segments_save(mock_paper: 'pd.Series', local_file_handler: 'MagicMock', mock_nougat: 'MagicMock'): +def test_create_or_load_nougat_segments_save( + mock_paper: "pd.Series", local_file_handler: "MagicMock", 
mock_nougat: "MagicMock" +): local_file_handler.exists.return_value = False mock_nougat.NougatOCR.return_value.predict.return_value = MagicMock() - with patch("os.makedirs") as mock_makedirs, patch("os.environ", {}), patch("extralit.preprocessing.document.isinstance", return_value=True): + with ( + patch("os.makedirs") as mock_makedirs, + patch("os.environ", {}), + patch("extralit.preprocessing.document.isinstance", return_value=True), + ): texts, tables, figures = create_or_load_nougat_segments( paper=mock_paper, load_only=False, redo=False, save=True, file_handler=local_file_handler, - nougat_model=mock_nougat.NougatOCR() + nougat_model=mock_nougat.NougatOCR(), ) assert texts is not None assert tables is not None assert figures is None # mock_makedirs.assert_called_with('data/preprocessing/nougat/test-paper', exist_ok=True) - mock_nougat.NougatOCR.return_value.predict.assert_called_with('/tmp/test_pdf.pdf') - local_file_handler.write_text.assert_any_call('data/preprocessing/nougat/test-paper/tables.json', tables.json()) + mock_nougat.NougatOCR.return_value.predict.assert_called_with("/tmp/test_pdf.pdf") + local_file_handler.write_text.assert_any_call("data/preprocessing/nougat/test-paper/tables.json", tables.json()) -def test_create_or_load_nougat_segments_load_from_s3(mock_paper: 'pd.Series', s3_file_handler: 'FileHandler', mock_nougat: 'MagicMock'): +def test_create_or_load_nougat_segments_load_from_s3( + mock_paper: "pd.Series", s3_file_handler: "FileHandler", mock_nougat: "MagicMock" +): # Mock the minio client methods - s3_file_handler.client.stat_object.side_effect = lambda bucket, path: None if path.endswith('predictions.json') else Exception() - s3_file_handler.client.get_object.side_effect = lambda bucket, path: MagicMock(read=lambda: Segments().json().encode('utf-8')) + s3_file_handler.client.stat_object.side_effect = ( + lambda bucket, path: None if path.endswith("predictions.json") else Exception() + ) + s3_file_handler.client.get_object.side_effect = 
lambda bucket, path: MagicMock( + read=lambda: Segments().json().encode("utf-8") + ) # Mock the file handler methods - s3_file_handler.read_text = MagicMock(side_effect=[ - Segments().json(), Segments().json(), Segments().json() - ]) + s3_file_handler.read_text = MagicMock(side_effect=[Segments().json(), Segments().json(), Segments().json()]) - with patch("glob.glob", return_value=['data/preprocessing/nougat/test-paper/predictions.json']), patch("extralit.preprocessing.document.isinstance", return_value=True): + with ( + patch("glob.glob", return_value=["data/preprocessing/nougat/test-paper/predictions.json"]), + patch("extralit.preprocessing.document.isinstance", return_value=True), + ): texts, tables, figures = create_or_load_nougat_segments( paper=mock_paper, load_only=True, @@ -161,4 +195,4 @@ def test_create_or_load_nougat_segments_load_from_s3(mock_paper: 'pd.Series', s3 assert texts is not None assert tables is not None assert figures is not None - s3_file_handler.read_text.assert_any_call('data/preprocessing/nougat/test-paper/tables.json') + s3_file_handler.read_text.assert_any_call("data/preprocessing/nougat/test-paper/tables.json") diff --git a/argilla-v1/tests/extralit/server/test_app.py b/argilla-v1/tests/extralit/server/test_app.py index d932826ed..b190bca6b 100644 --- a/argilla-v1/tests/extralit/server/test_app.py +++ b/argilla-v1/tests/extralit/server/test_app.py @@ -5,7 +5,7 @@ import pytest from unittest.mock import MagicMock, patch -from extralit.extraction.models.schema import SchemaStructure +from extralit_v1.extraction.models.schema import SchemaStructure if TYPE_CHECKING: from fastapi.testclient import TestClient @@ -13,6 +13,7 @@ from tests.extralit.helpers import mock_chat_completion_stream_v1 + class CachedOpenAIApiKeys: """ Saves the users' OpenAI API key and OpenAI API type either in @@ -60,6 +61,7 @@ def test_health_check(client: "TestClient"): assert response.status_code == 200 assert response.json() == {"status": "ok"} + def 
test_schemas(client: "TestClient", mocker: "MockerFixture"): mock_schema_structure = mocker.patch("extralit.server.app.SchemaStructure.from_s3") mock_schema_structure.return_value.ordering = {"schema": "value"} @@ -71,99 +73,91 @@ def test_schemas(client: "TestClient", mocker: "MockerFixture"): @patch("llama_index.llms.openai.base.SyncOpenAI") def test_chat( - MockSyncOpenAI: MagicMock, client: "TestClient", mocker: "MockerFixture", + MockSyncOpenAI: MagicMock, + client: "TestClient", + mocker: "MockerFixture", ): with CachedOpenAIApiKeys(set_fake_key=True): mock_instance = MockSyncOpenAI.return_value - mock_instance.chat.completions.create.return_value = ( - mock_chat_completion_stream_v1(responses=["res", "ponse"]) - ) + mock_instance.chat.completions.create.return_value = mock_chat_completion_stream_v1(responses=["res", "ponse"]) mock_vectordb_contains_any = mocker.patch("extralit.server.app.vectordb_contains_any") mock_vectordb_contains_any.return_value = True - response = client.get("/chat", params={ - "query": "test query", - "workspace": "test-workspace", - "reference": "test-reference", - "k": 5, - "chat_mode": "best", - "llm_model": "gpt-3.5-turbo", - "args": [None], - "kwargs": {} - }) - + response = client.get( + "/chat", + params={ + "query": "test query", + "workspace": "test-workspace", + "reference": "test-reference", + "k": 5, + "chat_mode": "best", + "llm_model": "gpt-3.5-turbo", + "args": [None], + "kwargs": {}, + }, + ) + assert response.status_code == 200 assert response.content == b"response" MockSyncOpenAI.assert_called_once() -def test_extraction( - client: "TestClient", - mocker: "MockerFixture", - schema_structure: "SchemaStructure" -): +def test_extraction(client: "TestClient", mocker: "MockerFixture", schema_structure: "SchemaStructure"): with CachedOpenAIApiKeys(set_fake_key=True): # mock_load_index = mocker.patch("extralit.server.app.load_index") - + mock_extract_schema = mocker.patch("extralit.server.app.extract_schema") 
mock_extract_schema.return_value = (pd.DataFrame({"col": ["value"]}), MagicMock()) - + mock_schema_structure = mocker.patch("extralit.server.app.SchemaStructure.from_s3") mock_schema_structure.return_value = schema_structure response = client.post( - "/extraction", + "/extraction", json={ "reference": "test-reference", "schema_name": "MockSchema", - "extractions": { - "MockSchema": [{"key": "value"}] - }, + "extractions": {"MockSchema": [{"key": "value"}]}, "columns": ["col"], "headers": ["header"], "types": None, "prompt": "test prompt", - }, - params={ - "workspace": "test-workspace", - "args": [None], - "kwargs": {} - } + }, + params={"workspace": "test-workspace", "args": [None], "kwargs": {}}, ) - + assert response.status_code == 201 assert response.json() == { "data": [{"col": "value", "index": 0}], - 'schema': { - 'fields': [ - {'extDtype': None, 'name': 'index', 'type': 'integer'}, - {'extDtype': None, 'name': 'col','type': 'string'} + "schema": { + "fields": [ + {"extDtype": None, "name": "index", "type": "integer"}, + {"extDtype": None, "name": "col", "type": "string"}, ], - 'pandas_version': '1.4.0', - 'primaryKey': ['index'] + "pandas_version": "1.4.0", + "primaryKey": ["index"], }, } def test_segments(client: "TestClient", mocker: "MockerFixture"): mock_get_nodes_metadata = mocker.patch("extralit.server.app.get_nodes_metadata") - mock_get_nodes_metadata.return_value = [{ - "doc_id": "test-doc-id", - "header": "test-header", - "page_number": 1, - "key": "value", - }] - - response = client.get("/segments/", params={ - "workspace": "test-workspace", - "reference": "test-reference", - "types": ["text"], - "limit": 100 - }) + mock_get_nodes_metadata.return_value = [ + { + "doc_id": "test-doc-id", + "header": "test-header", + "page_number": 1, + "key": "value", + } + ] + + response = client.get( + "/segments/", + params={"workspace": "test-workspace", "reference": "test-reference", "types": ["text"], "limit": 100}, + ) assert response.status_code == 200 - 
assert response.json() == {'items': [ - {'doc_id': 'test-doc-id', 'header': 'test-header', 'page_number': 1, 'type': None} - ] - } \ No newline at end of file + assert response.json() == { + "items": [{"doc_id": "test-doc-id", "header": "test-header", "page_number": 1, "type": None}] + } diff --git a/argilla-v1/tests/unit/feedback/__init__.py b/argilla-v1/tests/unit/feedback/__init__.py deleted file mode 100644 index 55be41799..000000000 --- a/argilla-v1/tests/unit/feedback/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/argilla-v1/tests/unit/feedback/conftest.py b/argilla-v1/tests/unit/feedback/conftest.py deleted file mode 100644 index f9fcc5065..000000000 --- a/argilla-v1/tests/unit/feedback/conftest.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import TYPE_CHECKING, List - -import pytest -from argilla_v1.client.feedback.schemas.fields import TextField -from argilla_v1.client.feedback.schemas.metadata import ( - FloatMetadataProperty, - IntegerMetadataProperty, - TermsMetadataProperty, -) -from argilla_v1.client.feedback.schemas.questions import TextQuestion - -if TYPE_CHECKING: - from argilla_v1.client.feedback.schemas.types import ( - AllowedFieldTypes, - AllowedMetadataPropertyTypes, - AllowedQuestionTypes, - ) - - -@pytest.fixture -def rating_question_payload() -> dict: - return { - "name": "label", - "description": "label", - "required": True, - "values": ["1", "2"], - } - - -@pytest.fixture -def label_question_payload() -> dict: - return { - "name": "label", - "description": "label", - "required": True, - "labels": ["1", "2"], - } - - -@pytest.fixture -def ranking_question_payload() -> dict: - return { - "name": "label", - "description": "label", - "required": True, - "values": ["1", "2"], - } - - -@pytest.fixture -def feedback_dataset_guidelines() -> str: - return "guidelines" - - -@pytest.fixture -def feedback_dataset_fields() -> List["AllowedFieldTypes"]: - return [ - TextField(name="text-field", required=True), - ] - - -@pytest.fixture -def feedback_dataset_questions() -> List["AllowedQuestionTypes"]: - return [ - TextQuestion(name="text-question", description="text", required=True), - ] - - -@pytest.fixture -def feedback_dataset_metadata_properties() -> List["AllowedMetadataPropertyTypes"]: - return [ - TermsMetadataProperty(name="terms-metadata", values=["1", "2"]), - IntegerMetadataProperty(name="integer-metadata", min=0, max=10), - FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0), - ] diff --git a/argilla-v1/tests/unit/feedback/dataset/__init__.py b/argilla-v1/tests/unit/feedback/dataset/__init__.py deleted file mode 100644 index 55be41799..000000000 --- a/argilla-v1/tests/unit/feedback/dataset/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021-present, 
the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/argilla-v1/tests/unit/feedback/dataset/local/__init__.py b/argilla-v1/tests/unit/feedback/dataset/local/__init__.py deleted file mode 100644 index 55be41799..000000000 --- a/argilla-v1/tests/unit/feedback/dataset/local/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/argilla-v1/tests/unit/feedback/dataset/local/test_dataset.py b/argilla-v1/tests/unit/feedback/dataset/local/test_dataset.py deleted file mode 100644 index 31bdb757b..000000000 --- a/argilla-v1/tests/unit/feedback/dataset/local/test_dataset.py +++ /dev/null @@ -1,554 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING, List, Type - -import numpy.array_api -import pytest -from argilla_v1 import RatingQuestion -from argilla_v1.client.feedback.dataset.local.dataset import FeedbackDataset -from argilla_v1.client.feedback.schemas.fields import TextField -from argilla_v1.client.feedback.schemas.metadata import ( - FloatMetadataProperty, - IntegerMetadataProperty, - TermsMetadataProperty, -) -from argilla_v1.client.feedback.schemas.questions import TextQuestion -from argilla_v1.client.feedback.schemas.records import FeedbackRecord -from argilla_v1.client.feedback.schemas.vector_settings import VectorSettings - -if TYPE_CHECKING: - from argilla_v1.client.feedback.schemas.types import ( - AllowedFieldTypes, - AllowedMetadataPropertyTypes, - AllowedQuestionTypes, - ) - - -@pytest.mark.parametrize( - "record", - [ - FeedbackRecord(fields={"required-field": "text"}, metadata={"nested-metadata": {"a": 1}}), - FeedbackRecord( - fields={"required-field": "text", "optional-field": "text"}, - metadata={"int-metadata": 1, "float-metadata": 1.0}, - ), - FeedbackRecord( - fields={"required-field": "text", "optional-field": None}, - metadata={"terms-metadata": "a", "more-metadata": 3}, - ), - FeedbackRecord( - fields={"required-field": "text", "optional-field": None}, - vectors={"vector-1": [1.0, 2.0, 3.0], "vector-2": [1.0, 2.0, 3.0, 4.0]}, - ), - ], -) -def test_add_records_validation(record: "FeedbackRecord") -> None: - dataset = FeedbackDataset( - fields=[TextField(name="required-field", required=True), TextField(name="optional-field", 
required=False)], - questions=[TextQuestion(name="question", required=True)], - metadata_properties=[ - TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"]), - IntegerMetadataProperty(name="int-metadata", min=0, max=10), - FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0), - ], - vectors_settings=[ - VectorSettings(name="vector-1", dimensions=3), - VectorSettings(name="vector-2", dimensions=4), - ], - ) - - dataset.add_records(record) - assert len(dataset.records) == 1 - assert dataset.records[0] == record - - -def test_update_records_with_warning() -> None: - dataset = FeedbackDataset( - fields=[TextField(name="required-field")], - questions=[TextQuestion(name="question")], - ) - - with pytest.warns( - UserWarning, - match="`update_records` method only works for `FeedbackDataset` pushed to Argilla." - " If your are working with local data, you can just iterate over the records and update them.", - ): - dataset.update_records( - FeedbackRecord(fields={"required-field": "text"}, metadata={"nested-metadata": {"a": 1}}) - ) - - -@pytest.mark.parametrize( - "record, allow_extra_metadata, exception_cls, exception_msg", - [ - (FeedbackRecord(fields={}, metadata={}), True, ValueError, "required-field\n field required"), - ( - FeedbackRecord(fields={"optional-field": "text"}, metadata={}), - True, - ValueError, - "required-field\n field required", - ), - ( - FeedbackRecord(fields={"required-field": "text"}, metadata={"terms-metadata": "d"}), - True, - ValueError, - "terms-metadata\n Provided 'terms-metadata=d' is not valid, only values in \['a', 'b', 'c'\] are allowed.", - ), - ( - FeedbackRecord(fields={"required-field": "text"}, metadata={"int-metadata": 11}), - True, - ValueError, - "int-metadata\n Provided 'int-metadata=11' is not valid, only values between 0 and 10 are allowed.", - ), - ( - FeedbackRecord(fields={"required-field": "text"}, metadata={"float-metadata": 11.0}), - True, - ValueError, - "float-metadata\n Provided 
'float-metadata=11.0' is not valid, only values between 0.0 and 10.0 are allowed.", - ), - ( - FeedbackRecord(fields={"required-field": "text"}, metadata={"extra-metadata": "yes"}), - False, - ValueError, - "extra fields not permitted", - ), - ( - FeedbackRecord( - fields={"required-field": "text"}, - vectors={ - "vector-1": [1.0, 2.0, 3.0, 4.0], - "vector-2": [1.0, 2.0, 3.0, 4.0], - }, - ), - False, - ValueError, - "Vector with name `vector-1` has an invalid expected dimension.", - ), - ( - FeedbackRecord( - fields={"required-field": "text"}, - vectors={"vector-1": [1.0, 2.0, 3.0], "vector-2": [1.0, 2.0, 3.0, 4.0, 5.0]}, - ), - False, - ValueError, - "Vector with name `vector-2` has an invalid expected dimension.", - ), - ( - FeedbackRecord( - fields={"required-field": "text"}, - vectors={ - "vector-1": [1.0, 2.0, 3.0], - "vector-2": [1.0, 2.0, 3.0, 4.0], - "vector-3": [1.0, 2.0, 3.0], - }, - ), - False, - ValueError, - "Vector with name `vector-3` not present on dataset vector settings.", - ), - ], -) -def test_add_records_validation_error( - record: FeedbackRecord, allow_extra_metadata: bool, exception_cls: Exception, exception_msg: str -) -> None: - dataset = FeedbackDataset( - fields=[TextField(name="required-field", required=True), TextField(name="optional-field", required=False)], - questions=[TextQuestion(name="question", required=True)], - metadata_properties=[ - TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"]), - IntegerMetadataProperty(name="int-metadata", min=0, max=10), - FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0), - ], - vectors_settings=[ - VectorSettings(name="vector-1", dimensions=3), - VectorSettings(name="vector-2", dimensions=4), - ], - allow_extra_metadata=allow_extra_metadata, - ) - - with pytest.raises(exception_cls, match=exception_msg): - dataset.add_records(record) - assert len(dataset.records) == 0 - - -@pytest.mark.parametrize( - "metadata_property", - ( - 
TermsMetadataProperty(name="new-terms-metadata"), - TermsMetadataProperty(name="new-terms-metadata", values=["a", "b", "c"]), - IntegerMetadataProperty(name="new-integer-metadata"), - IntegerMetadataProperty(name="new-integer-metadata", min=0, max=10), - FloatMetadataProperty(name="new-float-metadata"), - FloatMetadataProperty(name="new-float-metadata", min=0, max=10), - ), -) -def test_add_metadata_property(metadata_property: "AllowedMetadataPropertyTypes") -> None: - dataset = FeedbackDataset( - fields=[ - TextField(name="required-field"), - TextField(name="optional-field", required=False), - ], - questions=[TextQuestion(name="question")], - ) - - new_metadata_property = dataset.add_metadata_property(metadata_property) - assert new_metadata_property.name == metadata_property.name - assert len(dataset.metadata_properties) == 1 - - current_number_of_metadata_properties = len(dataset.metadata_properties) - dataset.add_metadata_property(TermsMetadataProperty(name="new-metadata-property", values=["a", "b", "c"])) - assert len(dataset.metadata_properties) == current_number_of_metadata_properties + 1 - - -@pytest.mark.parametrize("property_class", [IntegerMetadataProperty, FloatMetadataProperty]) -@pytest.mark.parametrize("numpy_type", [numpy.int16, numpy.int32, numpy.int64, numpy.float16, numpy.float32]) -def test_add_record_with_numpy_values(property_class: Type["AllowedMetadataPropertyTypes"], numpy_type: Type) -> None: - dataset = FeedbackDataset( - fields=[ - TextField(name="required-field"), - TextField(name="optional-field", required=False), - ], - questions=[TextQuestion(name="question")], - ) - - metadata_property = property_class(name="numeric_property") - dataset.add_metadata_property(metadata_property) - - property_to_expected_type_msg = {IntegerMetadataProperty: "`int`", FloatMetadataProperty: "`int` or `float`"} - expected_type_msg = property_to_expected_type_msg[property_class] - - value = numpy_type(10.0) - record = 
FeedbackRecord(fields={"required-field": "text"}, metadata={"numeric_property": value}) - - with pytest.raises( - ValueError, - match=f"Provided 'numeric_property={value}' of type {str(numpy_type)} is not valid, " - f"only values of type {expected_type_msg} are allowed.", - ): - dataset.add_records(record) - - -@pytest.mark.parametrize( - "metadata_property", - ( - TermsMetadataProperty(name="terms-metadata"), - TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"]), - IntegerMetadataProperty(name="int-metadata"), - IntegerMetadataProperty(name="int-metadata", min=0, max=10), - FloatMetadataProperty(name="float-metadata"), - FloatMetadataProperty(name="float-metadata", min=0, max=10), - ), -) -def test_add_metadata_property_errors(metadata_property: "AllowedMetadataPropertyTypes") -> None: - dataset = FeedbackDataset( - fields=[TextField(name="required-field", required=True), TextField(name="optional-field", required=False)], - questions=[TextQuestion(name="question", required=True)], - metadata_properties=[ - TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"]), - IntegerMetadataProperty(name="int-metadata", min=0, max=10), - FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0), - ], - ) - - with pytest.raises( - ValueError, match=f"Invalid `metadata_property={metadata_property.name}` provided as it already exists" - ): - _ = dataset.add_metadata_property(metadata_property) - assert len(dataset.metadata_properties) == 3 - - -def test_update_metadata_properties() -> None: - dataset = FeedbackDataset( - fields=[TextField(name="required-field", required=True), TextField(name="optional-field", required=False)], - questions=[TextQuestion(name="question", required=True)], - metadata_properties=[ - TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"]), - IntegerMetadataProperty(name="int-metadata", min=0, max=10), - FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0), - ], - ) - for metadata_property in 
dataset.metadata_properties: - metadata_property.title = "new-title" - metadata_property.visible_for_annotators = False - - with pytest.warns( - UserWarning, match="`update_metadata_properties` method is not supported for `FeedbackDataset` datasets" - ): - dataset.update_metadata_properties(dataset.metadata_properties[0]) - - with pytest.warns( - UserWarning, match="`update_metadata_properties` method is not supported for `FeedbackDataset` datasets" - ): - dataset.update_metadata_properties(dataset.metadata_properties) - - -@pytest.mark.parametrize( - "metadata_properties", - ( - [TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"])], - [IntegerMetadataProperty(name="integer-metadata", min=0, max=10)], - [FloatMetadataProperty(name="float-metadata", min=0, max=10)], - [ - TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"]), - IntegerMetadataProperty(name="integer-metadata", min=0, max=10), - FloatMetadataProperty(name="float-metadata", min=0, max=10), - ], - ), -) -def test_delete_metadata_properties(metadata_properties: List["AllowedMetadataPropertyTypes"]) -> None: - dataset = FeedbackDataset( - fields=[TextField(name="required-field", required=True), TextField(name="optional-field", required=False)], - questions=[TextQuestion(name="question", required=True)], - metadata_properties=metadata_properties, - ) - - deleted_metadata_properties = dataset.delete_metadata_properties( - [metadata_property.name for metadata_property in metadata_properties] - ) - assert len(dataset.metadata_properties) == 0 - deleted_metadata_properties = ( - deleted_metadata_properties if isinstance(deleted_metadata_properties, list) else [deleted_metadata_properties] - ) - assert all( - metadata_property.name - in [deleted_metadata_property.name for deleted_metadata_property in deleted_metadata_properties] - for metadata_property in metadata_properties - ) - - -def test_delete_metadata_properties_errors() -> None: - dataset = FeedbackDataset( - 
fields=[TextField(name="required-field", required=True), TextField(name="optional-field", required=False)], - questions=[TextQuestion(name="question", required=True)], - metadata_properties=[ - TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"]), - IntegerMetadataProperty(name="int-metadata", min=0, max=10), - FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0), - ], - ) - - with pytest.raises( - ValueError, - match="Invalid `metadata_properties=\['invalid-metadata'\]` provided. It cannot be deleted because it does not exist", - ): - _ = dataset.delete_metadata_properties(["invalid-metadata"]) - assert len(dataset.metadata_properties) == 3 - - -def test_delete_vectors_settings() -> None: - dataset = FeedbackDataset( - fields=[TextField(name="field", required=True)], - questions=[TextQuestion(name="question", required=True)], - vectors_settings=[ - VectorSettings(name="vector-settings-1", dimensions=10), - VectorSettings(name="vector-settings-2", dimensions=10), - VectorSettings(name="vector-settings-3", dimensions=10), - ], - ) - - deleted_vectors_settings = dataset.delete_vectors_settings("vector-settings-1") - assert isinstance(deleted_vectors_settings, VectorSettings) - assert len(dataset.vectors_settings) == 2 - - deleted_vectors_settings = dataset.delete_vectors_settings(["vector-settings-2", "vector-settings-3"]) - assert isinstance(deleted_vectors_settings, list) - assert len(deleted_vectors_settings) == 2 - assert len(dataset.vectors_settings) == 0 - - -def test_not_implemented_methods(): - dataset = FeedbackDataset( - fields=[TextField(name="required-field", required=True), TextField(name="optional-field", required=False)], - questions=[TextQuestion(name="question", required=True)], - metadata_properties=[ - TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"]), - IntegerMetadataProperty(name="int-metadata", min=0, max=10), - FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0), - ], - ) - - with 
pytest.warns( - UserWarning, match="`sort_by` method is not supported for local datasets and won't take any effect. " - ): - assert dataset.sort_by("field") == dataset - - with pytest.warns( - UserWarning, match="`filter_by` method is not supported for local datasets and won't take any effect. " - ): - assert dataset.filter_by() == dataset - - -def test_init( - feedback_dataset_guidelines: str, - feedback_dataset_fields: List["AllowedFieldTypes"], - feedback_dataset_questions: List["AllowedQuestionTypes"], -) -> None: - dataset = FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=feedback_dataset_fields, - questions=feedback_dataset_questions, - allow_extra_metadata=False, - ) - - assert dataset.guidelines == feedback_dataset_guidelines - assert dataset.fields == feedback_dataset_fields - assert dataset.questions == feedback_dataset_questions - assert dataset.allow_extra_metadata == False - - -def test_init_wrong_guidelines( - feedback_dataset_fields: List["AllowedFieldTypes"], feedback_dataset_questions: List["AllowedQuestionTypes"] -) -> None: - with pytest.raises(TypeError, match="Expected `guidelines` to be"): - FeedbackDataset( - guidelines=[], - fields=feedback_dataset_fields, - questions=feedback_dataset_questions, - ) - with pytest.raises(ValueError, match="Expected `guidelines` to be"): - FeedbackDataset( - guidelines="", - fields=feedback_dataset_fields, - questions=feedback_dataset_questions, - ) - - -def test_init_wrong_fields( - feedback_dataset_guidelines: str, feedback_dataset_questions: List["AllowedQuestionTypes"] -) -> None: - with pytest.raises(TypeError, match="Expected `fields` to be a list"): - FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=None, - questions=feedback_dataset_questions, - ) - with pytest.raises(TypeError, match="Expected `fields` to be a list of `TextField`"): - FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=[{"wrong": "field"}], - questions=feedback_dataset_questions, 
- ) - with pytest.raises(ValueError, match="At least one field in `fields` must be required"): - FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=[TextField(name="test", required=False)], - questions=feedback_dataset_questions, - ) - with pytest.raises(ValueError, match="Expected `fields` to have unique names"): - FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=[ - TextField(name="test", required=True), - TextField(name="test", required=True), - ], - questions=feedback_dataset_questions, - ) - - -def test_init_wrong_questions( - feedback_dataset_guidelines: str, feedback_dataset_fields: List["AllowedFieldTypes"] -) -> None: - with pytest.raises(TypeError, match="Expected `questions` to be a list, got"): - FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=feedback_dataset_fields, - questions=None, - ) - with pytest.raises( - TypeError, - match="Expected `questions` to be a list of", - ): - FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=feedback_dataset_fields, - questions=[{"wrong": "question"}], - ) - with pytest.raises(ValueError, match="At least one question in `questions` must be required"): - FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=feedback_dataset_fields, - questions=[ - TextQuestion(name="question-1", required=False), - RatingQuestion(name="question-2", values=[1, 2], required=False), - ], - ) - with pytest.raises(ValueError, match="Expected `questions` to have unique names"): - FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=feedback_dataset_fields, - questions=[ - TextQuestion(name="question-1", required=True), - TextQuestion(name="question-1", required=True), - ], - ) - - -def test_init_wrong_metadata_properties( - feedback_dataset_guidelines: str, - feedback_dataset_fields: List["AllowedFieldTypes"], - feedback_dataset_questions: List["AllowedQuestionTypes"], -) -> None: - with pytest.raises(TypeError, match="Expected 
`metadata_properties` to be a list"): - FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=feedback_dataset_fields, - questions=feedback_dataset_questions, - metadata_properties=["wrong type"], - ) - with pytest.raises(ValueError, match="Expected `metadata_properties` to have unique names"): - FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=feedback_dataset_fields, - questions=feedback_dataset_questions, - metadata_properties=[ - IntegerMetadataProperty(name="metadata-property-1", min=0, max=10), - IntegerMetadataProperty(name="metadata-property-1", min=0, max=10), - ], - ) - - -@pytest.mark.parametrize( - "metadata_property", - ( - TermsMetadataProperty(name="terms-metadata-diff-name", values=["a", "b", "c"]), - IntegerMetadataProperty(name="int-metadata-diff-name", min=0, max=10), - FloatMetadataProperty(name="float-metadata-diff-name", min=0.0, max=10.0), - ), -) -def test__unique_metadata_property(metadata_property: "AllowedMetadataPropertyTypes") -> None: - dataset = FeedbackDataset( - fields=[TextField(name="required-field", required=True), TextField(name="optional-field", required=False)], - questions=[TextQuestion(name="question", required=True)], - metadata_properties=[ - TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"]), - IntegerMetadataProperty(name="int-metadata", min=0, max=10), - FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0), - ], - ) - dataset._unique_metadata_property(metadata_property) - - -def test_properties_by_name() -> None: - dataset = FeedbackDataset( - fields=[TextField(name="required-field", required=True), TextField(name="optional-field", required=False)], - questions=[TextQuestion(name="question", required=True)], - metadata_properties=[TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"])], - ) - assert dataset.field_by_name("mock") is None - assert dataset.question_by_name("mock") is None - assert dataset.metadata_property_by_name("mock") is None - 
assert isinstance(dataset.field_by_name("required-field"), TextField) - assert isinstance(dataset.question_by_name("question"), TextQuestion) - assert isinstance(dataset.metadata_property_by_name("terms-metadata"), TermsMetadataProperty) diff --git a/argilla-v1/tests/unit/feedback/dataset/remote/__init__.py b/argilla-v1/tests/unit/feedback/dataset/remote/__init__.py deleted file mode 100644 index 55be41799..000000000 --- a/argilla-v1/tests/unit/feedback/dataset/remote/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/argilla-v1/tests/unit/feedback/dataset/remote/test_dataset.py b/argilla-v1/tests/unit/feedback/dataset/remote/test_dataset.py deleted file mode 100644 index b194e4fff..000000000 --- a/argilla-v1/tests/unit/feedback/dataset/remote/test_dataset.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime -from typing import Dict -from uuid import uuid4 - -import httpx -import pytest -from argilla_v1 import FeedbackDataset, FeedbackRecord, Workspace -from argilla_v1.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset -from argilla_v1.client.feedback.schemas import SuggestionSchema -from argilla_v1.client.feedback.schemas.remote.fields import RemoteTextField -from argilla_v1.client.feedback.schemas.remote.questions import RemoteTextQuestion -from argilla_v1.client.feedback.schemas.remote.records import RemoteFeedbackRecord -from argilla_v1.client.feedback.schemas.vector_settings import VectorSettings -from argilla_v1.client.sdk.users.models import UserModel, UserRole -from argilla_v1.client.sdk.v1.datasets.models import ( - FeedbackItemModel, - FeedbackListVectorSettingsModel, - FeedbackSuggestionModel, - FeedbackVectorSettingsModel, -) -from argilla_v1.client.sdk.v1.workspaces.models import WorkspaceModel -from pytest_mock import MockerFixture - - -@pytest.fixture() -def test_remote_dataset(mock_httpx_client: httpx.Client) -> RemoteFeedbackDataset: - return RemoteFeedbackDataset( - client=mock_httpx_client, - id=uuid4(), - name="test-remote-dataset", - workspace=Workspace._new_instance( - client=mock_httpx_client, - ws=WorkspaceModel( - id=uuid4(), - name="test-remote-workspace", - inserted_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ), - ), - fields=[RemoteTextField(id=uuid4(), name="text")], - questions=[RemoteTextQuestion(id=uuid4(), name="text")], - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ) - - -@pytest.fixture() -def test_remote_record( - mock_httpx_client: httpx.Client, test_remote_dataset: RemoteFeedbackDataset -) -> RemoteFeedbackRecord: - return RemoteFeedbackRecord( - id=uuid4(), - client=mock_httpx_client, - fields={"text": "test"}, - metadata={"new": "metadata"}, - 
question_name_to_id=test_remote_dataset._question_name_to_id, - ) - - -def configure_mock_routes(mock_httpx_client: httpx.Client, mock_routes: Dict) -> None: - def _mock_route(routes: Dict[str, httpx.Response]): - return lambda url, **kwargs: routes[url] - - for method, routes in mock_routes.items(): - getattr(mock_httpx_client, method).side_effect = _mock_route(routes) - - -def create_mock_routes( - test_remote_dataset: RemoteFeedbackDataset, test_remote_record: RemoteFeedbackRecord -) -> Dict[str, Dict[str, httpx.Response]]: - routes = { - "put": {}, - "post": {}, - "delete": {}, - "get": { - "/api/me": httpx.Response( - status_code=200, - content=UserModel( - id=uuid4(), - first_name="test", - username="test", - role=UserRole.owner, - api_key="api.key", - inserted_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ).json(), - ), - f"/api/v1/me/datasets/{test_remote_dataset.id}/metadata-properties": httpx.Response( - status_code=200, json={"items": []} - ), - f"/api/v1/datasets/{test_remote_dataset.id}/vectors-settings": httpx.Response( - status_code=200, - content=FeedbackListVectorSettingsModel( - items=[ - FeedbackVectorSettingsModel( - id=uuid4(), - name="vector-1", - title="Vector 1", - dimensions=3, - inserted_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ), - FeedbackVectorSettingsModel( - id=uuid4(), - name="vector-2", - title="Vector 2", - dimensions=4, - inserted_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ), - ] - ).json(), - ), - }, - "patch": { - f"/api/v1/records/{test_remote_record.id}": httpx.Response( - status_code=200, - content=FeedbackItemModel( - id=test_remote_record.id, - fields=test_remote_record.fields, - inserted_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ).json(), - ), - f"/api/v1/datasets/{test_remote_dataset.id}/records": httpx.Response( - status_code=204, - ), - }, - } - return routes - - -class TestSuiteRemoteDataset: - def test_update_records( - self, - mock_httpx_client: httpx.Client, - 
test_remote_dataset: RemoteFeedbackDataset, - test_remote_record: RemoteFeedbackRecord, - ) -> None: - """Test updating records.""" - - mock_routes = create_mock_routes(test_remote_dataset, test_remote_record) - configure_mock_routes(mock_httpx_client, mock_routes) - - test_remote_dataset.update_records(records=[test_remote_record]) - - mock_httpx_client.patch.assert_called_once_with( - url=f"/api/v1/datasets/{test_remote_dataset.id}/records", - json={"items": [ - {"id": str(test_remote_record.id), "fields": {"text": "test"}, "suggestions": [], "metadata": {"new": "metadata"}} - ] - }, - ) - - def test_update_multiple_records( - self, - mock_httpx_client: httpx.Client, - test_remote_dataset: RemoteFeedbackDataset, - test_remote_record: RemoteFeedbackRecord, - ) -> None: - """Test updating records.""" - - mock_routes = create_mock_routes(test_remote_dataset, test_remote_record) - configure_mock_routes(mock_httpx_client, mock_routes) - - mock_httpx_client.patch.return_value = httpx.Response(status_code=204) - - test_remote_dataset.update_records(records=[test_remote_record] * 10) - - assert mock_httpx_client.patch.call_count == 1 - - def test_update_records_with_multiple_suggestions( - self, - mock_httpx_client: httpx.Client, - test_remote_dataset: RemoteFeedbackDataset, - test_remote_record: RemoteFeedbackRecord, - ) -> None: - """Test updating records.""" - mock_routes = create_mock_routes(test_remote_dataset, test_remote_record) - configure_mock_routes(mock_httpx_client, mock_routes) - - test_remote_record.suggestions = [ - SuggestionSchema(question_name="text", value="Test value", score=0.5, agent="test") - ] * 10 - - test_remote_dataset.update_records(records=[test_remote_record] * 10) - - assert mock_httpx_client.patch.call_count == 1 - - def test_update_records_suggestions( - self, - mock_httpx_client: httpx.Client, - test_remote_dataset: RemoteFeedbackDataset, - test_remote_record: RemoteFeedbackRecord, - ) -> None: - expected_suggestion = 
FeedbackSuggestionModel( - id=uuid4(), - question_id=str(test_remote_dataset.question_by_name("text").id), - value="Test value", - score=0.5, - agent="test", - ) - - mock_routes = create_mock_routes(test_remote_dataset, test_remote_record) - configure_mock_routes(mock_httpx_client, mock_routes) - - test_remote_record.suggestions = [ - SuggestionSchema(question_name="text", value="Test value", score=0.5, agent="test") - ] - - test_remote_dataset.update_records(records=test_remote_record) - - mock_httpx_client.patch.assert_called_with( - url=f"/api/v1/datasets/{test_remote_dataset.id}/records", - # TODO: This should be a list of suggestions - json={ - "items": [ - { - "id": str(test_remote_record.id), - "metadata": {"new": "metadata"}, - "fields": {"text": "test"}, - "suggestions": [ - { - "agent": expected_suggestion.agent, - "question_id": str(expected_suggestion.question_id), - "score": expected_suggestion.score, - "value": expected_suggestion.value, - } - ], - } - ] - }, - ) - - def test_update_records_suggestions_with_already_suggestion( - self, - mock_httpx_client: httpx.Client, - test_remote_dataset: RemoteFeedbackDataset, - test_remote_record: RemoteFeedbackRecord, - ) -> None: - # TODO: Implement - pass - - def test_push_to_huggingface_warnings( - self, mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch, test_remote_dataset: RemoteFeedbackDataset - ) -> None: - monkeypatch.setattr(test_remote_dataset, "pull", lambda: mocker.Mock(FeedbackDataset)) - with pytest.warns( - UserWarning, - match="The dataset is first pulled locally and pushed to Hugging Face after " - "because `push_to_huggingface` is not supported for a `RemoteFeedbackDataset`", - ): - test_remote_dataset.push_to_huggingface("repo_id") - - def test_add_vector_settings( - self, - mock_httpx_client: httpx.Client, - test_remote_dataset: RemoteFeedbackDataset, - test_remote_record: RemoteFeedbackRecord, - ) -> None: - mock_routes = create_mock_routes(test_remote_dataset, test_remote_record) - - 
expected_name = "mock-vector" - expected_title = "Mock Vector" - expected_dimensions = 100 - mock_routes["post"].update( - { - f"/api/v1/datasets/{test_remote_dataset.id}/vectors-settings": httpx.Response( - status_code=201, - content=FeedbackVectorSettingsModel( - id=uuid4(), - name=expected_name, - title=expected_title, - dimensions=expected_dimensions, - inserted_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ).json(), - ) - } - ) - - configure_mock_routes(mock_httpx_client, mock_routes) - - test_remote_dataset.add_vector_settings( - vector_settings=VectorSettings(name=expected_name, title=expected_title, dimensions=expected_dimensions) - ) - - mock_httpx_client.post.assert_called_once_with( - url=f"/api/v1/datasets/{test_remote_dataset.id}/vectors-settings", - json={"name": expected_name, "title": expected_title, "dimensions": expected_dimensions}, - ) - - def test_add_records_with_vectors( - self, - mock_httpx_client: httpx.Client, - test_remote_dataset: RemoteFeedbackDataset, - ) -> None: - mock_routes = { - "post": {f"/api/v1/datasets/{test_remote_dataset.id}/records": httpx.Response(status_code=204)}, - "get": { - "/api/me": httpx.Response( - status_code=200, - content=UserModel( - id=uuid4(), - first_name="test", - username="test", - role=UserRole.owner, - api_key="api.key", - inserted_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ).json(), - ), - f"/api/v1/me/datasets/{test_remote_dataset.id}/metadata-properties": httpx.Response( - status_code=200, json={"items": []} - ), - f"/api/v1/datasets/{test_remote_dataset.id}/vectors-settings": httpx.Response( - status_code=200, - content=FeedbackListVectorSettingsModel( - items=[ - FeedbackVectorSettingsModel( - id=uuid4(), - name="vector-1", - title="Vector 1", - dimensions=3, - inserted_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ), - FeedbackVectorSettingsModel( - id=uuid4(), - name="vector-2", - title="Vector 2", - dimensions=4, - inserted_at=datetime.utcnow(), - 
updated_at=datetime.utcnow(), - ), - ] - ).json(), - ), - }, - } - - configure_mock_routes(mock_httpx_client, mock_routes) - test_remote_dataset.add_records(FeedbackRecord(fields={"text": "test"}, vectors={"vector-1": [1.0, 2.0, 3.0]})) - - mock_httpx_client.post.assert_called_once_with( - url=f"/api/v1/datasets/{test_remote_dataset.id}/records", - json={"items": [{"fields": {"text": "test"}, "suggestions": [], "vectors": {"vector-1": [1.0, 2.0, 3.0]}}]}, - ) diff --git a/argilla-v1/tests/unit/feedback/dataset/test_helpers.py b/argilla-v1/tests/unit/feedback/dataset/test_helpers.py deleted file mode 100644 index 530ad81a3..000000000 --- a/argilla-v1/tests/unit/feedback/dataset/test_helpers.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import TYPE_CHECKING, Dict, List, Type, Union - -import pytest -from argilla_v1.client.feedback.dataset.helpers import generate_pydantic_schema_for_metadata -from argilla_v1.client.feedback.schemas.metadata import ( - FloatMetadataProperty, - IntegerMetadataProperty, - TermsMetadataProperty, -) -from argilla_v1.client.feedback.schemas.remote.metadata import ( - RemoteFloatMetadataProperty, - RemoteIntegerMetadataProperty, - RemoteTermsMetadataProperty, -) - -from tests.pydantic_v1 import ValidationError - -if TYPE_CHECKING: - from argilla_v1.client.feedback.schemas.types import ( - AllowedMetadataPropertyTypes, - AllowedRemoteMetadataPropertyTypes, - ) - - -@pytest.mark.parametrize( - "metadata_properties, validation_data", - [ - ( - [TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"])], - {"terms-metadata": "a"}, - ), - ( - [IntegerMetadataProperty(name="int-metadata", min=0, max=10)], - {"int-metadata": 1}, - ), - ( - [FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0)], - {"float-metadata": 1.0}, - ), - ( - [ - TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"]), - IntegerMetadataProperty(name="int-metadata", min=0, max=10), - FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0), - ], - {"terms-metadata": "a", "int-metadata": 1, "float-metadata": 1.0}, - ), - ], -) -def test_generate_pydantic_schema_for_metadata( - metadata_properties: List["AllowedMetadataPropertyTypes"], validation_data: Dict[str, Union[str, int, float]] -) -> None: - MetadataSchema = generate_pydantic_schema_for_metadata( - metadata_properties=metadata_properties, name="MetadataSchema" - ) - assert MetadataSchema(**validation_data) - - -@pytest.mark.parametrize( - "metadata_properties, validation_data, exception_cls, exception_msg", - [ - ( - [TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"])], - {"terms-metadata": "d"}, - ValidationError, - "terms-metadata\n Provided 'terms-metadata=d' is not valid, 
only values in \['a', 'b', 'c'\] are allowed", - ), - ( - [TermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"])], - {"terms-metadata": 1}, - ValidationError, - "Provided 'terms-metadata=1' of type is not valid", - ), - ( - [IntegerMetadataProperty(name="int-metadata", min=0, max=10)], - {"int-metadata": -10}, - ValidationError, - "int-metadata\n Provided 'int-metadata=-10' is not valid, only values between 0 and 10 are allowed.", - ), - ( - [IntegerMetadataProperty(name="int-metadata", min=0, max=10)], - {"int-metadata": "wrong"}, - ValidationError, - "Provided 'int-metadata=wrong' of type is not valid", - ), - ( - [IntegerMetadataProperty(name="int-metadata", min=0, max=10)], - {"int-metadata": float("nan")}, - ValidationError, - "Provided 'int-metadata=nan' is not valid, NaN values are not allowed.", - ), - ( - [FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0)], - {"float-metadata": 100.0}, - ValidationError, - "float-metadata\n Provided 'float-metadata=100.0' is not valid, only values between 0.0 and 10.0 are allowed.", - ), - ( - [FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0)], - {"float-metadata": "wrong"}, - ValidationError, - "Provided 'float-metadata=wrong' of type is not valid", - ), - ( - [FloatMetadataProperty(name="float-metadata", min=0.0, max=10.0)], - {"float-metadata": float("nan")}, - ValidationError, - "Provided 'float-metadata=nan' is not valid, NaN values are not allowed.", - ), - ], -) -def test_generate_pydantic_schema_for_metadata_errors( - metadata_properties: List["AllowedMetadataPropertyTypes"], - validation_data: Dict[str, Union[str, int, float]], - exception_cls: Exception, - exception_msg: str, -) -> None: - MetadataSchema = generate_pydantic_schema_for_metadata( - metadata_properties=metadata_properties, name="MetadataSchema" - ) - with pytest.raises(exception_cls, match=exception_msg): - MetadataSchema(**validation_data) - - -@pytest.mark.parametrize( - "metadata_properties, 
validation_data", - [ - ( - [RemoteTermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"])], - {"terms-metadata": "a"}, - ), - ( - [RemoteIntegerMetadataProperty(name="int-metadata", min=0, max=10)], - {"int-metadata": 1}, - ), - ( - [RemoteFloatMetadataProperty(name="float-metadata", min=0.0, max=10.0)], - {"float-metadata": 1.0}, - ), - ( - [ - RemoteTermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"]), - RemoteIntegerMetadataProperty(name="int-metadata", min=0, max=10), - RemoteFloatMetadataProperty(name="float-metadata", min=0.0, max=10.0), - ], - {"terms-metadata": "a", "int-metadata": 1, "float-metadata": 1.0}, - ), - ], -) -def test_generate_pydantic_schema_for_remote_metadata( - metadata_properties: List["AllowedRemoteMetadataPropertyTypes"], validation_data: Dict[str, Union[str, int, float]] -) -> None: - RemoteMetadataSchema = generate_pydantic_schema_for_metadata( - metadata_properties=metadata_properties, name="RemoteMetadataSchema" - ) - assert RemoteMetadataSchema(**validation_data).dict() == validation_data - - -@pytest.mark.parametrize( - "metadata_properties, validation_data, exception_cls, exception_msg", - [ - ( - [RemoteTermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"])], - {"terms-metadata": "d"}, - ValidationError, - "terms-metadata\n Provided 'terms-metadata=d' is not valid, only values in \['a', 'b', 'c'\] are allowed", - ), - ( - [RemoteTermsMetadataProperty(name="terms-metadata", values=["a", "b", "c"])], - {"terms-metadata": 1}, - ValidationError, - "Provided 'terms-metadata=1' of type is not valid", - ), - ( - [RemoteIntegerMetadataProperty(name="int-metadata", min=0, max=10)], - {"int-metadata": -10}, - ValidationError, - "int-metadata\n Provided 'int-metadata=-10' is not valid, only values between 0 and 10 are allowed.", - ), - ( - [RemoteIntegerMetadataProperty(name="int-metadata", min=0, max=10)], - {"int-metadata": "wrong"}, - ValidationError, - "Provided 'int-metadata=wrong' of type is 
not valid", - ), - ( - [RemoteFloatMetadataProperty(name="float-metadata", min=0.0, max=10.0)], - {"float-metadata": 100.0}, - ValidationError, - "float-metadata\n Provided 'float-metadata=100.0' is not valid, only values between 0.0 and 10.0 are allowed.", - ), - ( - [RemoteFloatMetadataProperty(name="float-metadata", min=0.0, max=10.0)], - {"float-metadata": "wrong"}, - ValidationError, - "Provided 'float-metadata=wrong' of type is not valid", - ), - ], -) -def test_generate_pydantic_schema_for_remote_metadata_errors( - metadata_properties: List["AllowedRemoteMetadataPropertyTypes"], - validation_data: Dict[str, Union[str, int, float]], - exception_cls: Type[Exception], - exception_msg: str, -) -> None: - RemoteMetadataSchema = generate_pydantic_schema_for_metadata( - metadata_properties=metadata_properties, name="RemoteMetadataSchema" - ) - with pytest.raises(exception_cls, match=exception_msg): - RemoteMetadataSchema(**validation_data) diff --git a/argilla-v1/tests/unit/feedback/integrations/__init__.py b/argilla-v1/tests/unit/feedback/integrations/__init__.py deleted file mode 100644 index 55be41799..000000000 --- a/argilla-v1/tests/unit/feedback/integrations/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/argilla-v1/tests/unit/feedback/integrations/huggingface/__init__.py b/argilla-v1/tests/unit/feedback/integrations/huggingface/__init__.py deleted file mode 100644 index 55be41799..000000000 --- a/argilla-v1/tests/unit/feedback/integrations/huggingface/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/argilla-v1/tests/unit/feedback/integrations/huggingface/card/__init__.py b/argilla-v1/tests/unit/feedback/integrations/huggingface/card/__init__.py deleted file mode 100644 index 55be41799..000000000 --- a/argilla-v1/tests/unit/feedback/integrations/huggingface/card/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/argilla-v1/tests/unit/feedback/integrations/huggingface/card/test__dataset_card.py b/argilla-v1/tests/unit/feedback/integrations/huggingface/card/test__dataset_card.py deleted file mode 100644 index c044a4b2c..000000000 --- a/argilla-v1/tests/unit/feedback/integrations/huggingface/card/test__dataset_card.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import re -from typing import TYPE_CHECKING, List -from uuid import uuid4 - -import pytest -from argilla_v1.client.feedback.integrations.huggingface.card import ArgillaDatasetCard -from argilla_v1.client.feedback.schemas.fields import TextField -from argilla_v1.client.feedback.schemas.metadata import ( - FloatMetadataProperty, - IntegerMetadataProperty, - TermsMetadataProperty, -) -from argilla_v1.client.feedback.schemas.questions import ( - LabelQuestion, - MultiLabelQuestion, - RankingQuestion, - RatingQuestion, - TextQuestion, -) -from argilla_v1.client.feedback.schemas.records import FeedbackRecord -from argilla_v1.client.feedback.schemas.types import AllowedMetadataPropertyTypes -from argilla_v1.client.feedback.schemas.vector_settings import VectorSettings -from huggingface_hub import DatasetCardData - -if TYPE_CHECKING: - from argilla_v1.client.feedback.schemas import FeedbackRecord - from argilla_v1.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes - - -class 
TestSuiteArgillaDatasetCard: - @pytest.mark.parametrize( - "repo_id,fields,questions,guidelines,metadata_properties,vectors_settings,record", - [ - ( - f"argilla/dataset-card-{uuid4()}", - [TextField(name="text-field")], - [ - TextQuestion(name="text-question"), - RatingQuestion(name="rating-question", values=[1, 2, 3]), - LabelQuestion(name="label-question", labels=["a", "b", "c"]), - MultiLabelQuestion(name="multi-label-question", labels=["a", "b", "c"]), - RankingQuestion(name="ranking-question", values=["a", "b", "c"]), - ], - "## Guidelines", - [ - TermsMetadataProperty(name="color", values=["red", "blue"]), - IntegerMetadataProperty(name="day", min=0, max=31, visible_for_annotators=False), - FloatMetadataProperty(name="price", min=0, max=100), - ], - [VectorSettings(name="float-vector", dimensions=2)], - - FeedbackRecord( - fields={"text-field": "text"}, - responses=[ - { - "values": { - "text-question": {"value": "text"}, - "rating-question": {"value": 1}, - "label-question": {"value": "a"}, - "multi-label-question": {"value": ["a", "b"]}, - "ranking-question": {"value": ["a", "b", "c"]}, - }, - "user_id": str(uuid4()), - "status": "submitted", - }, - ], - suggestions=[ - { - "question_name": "text-question", - "value": "text", - }, - { - "question_name": "rating-question", - "value": 1, - }, - { - "question_name": "label-question", - "value": "a", - }, - { - "question_name": "multi-label-question", - "value": ["a", "b"], - }, - { - "question_name": "ranking-question", - "value": ["a", "b", "c"], - }, - ], - vectors={"float-vector": [1.0, 2.0]}, - external_id="external-id-1", - ), - ) - ], - ) - def test_from_template( - self, - repo_id: str, - fields: List["AllowedFieldTypes"], - questions: List["AllowedQuestionTypes"], - guidelines: str, - metadata_properties: List[AllowedMetadataPropertyTypes], - vectors_settings: List[VectorSettings], - record: FeedbackRecord, - ) -> None: - card = ArgillaDatasetCard.from_template( - card_data=DatasetCardData( - 
language="en", - size_categories="n<1K", - tags=["rlfh", "argilla", "human-feedback"], - ), - template_path=ArgillaDatasetCard.default_template_path, - repo_id=repo_id, - argilla_fields=fields, - argilla_questions=questions, - argilla_guidelines=guidelines, - argilla_metadata_properties=metadata_properties, - argilla_vectors_settings=vectors_settings, - argilla_record=json.loads(record.json()), - huggingface_record=record.json(), - ) - - assert isinstance(card, ArgillaDatasetCard) - assert card.default_template_path == ArgillaDatasetCard.default_template_path - assert card.content.__contains__(f"# Dataset Card for {repo_id.split('/')[1]}") - assert all(field.name in card.content for field in fields) - assert all(question.name in card.content for question in questions) - assert all(vector_settings.name in card.content for vector_settings in vectors_settings) - assert guidelines in card.content - assert re.search("\| color \| color \| terms \| \['red', 'blue'\] \| True \|", card.content) - assert re.search("\| day \| day \| integer \| 0 - 31 \| False \|", card.content) - assert re.search("\| price \| price \| float \| 0\.0 - 100\.0 \| True \|", card.content) - - # In case we implement new metadata_property types and forget about the poor little dataset cards ... - for allowed_metadata_type in AllowedMetadataPropertyTypes.__args__: - if allowed_metadata_type not in [type(metadata) for metadata in metadata_properties]: - raise NotImplementedError(f"ArgillaDatasetCard not tested for '{allowed_metadata_type}'.") diff --git a/argilla-v1/tests/unit/feedback/integrations/huggingface/test_dataset.py b/argilla-v1/tests/unit/feedback/integrations/huggingface/test_dataset.py deleted file mode 100644 index 27a025ad5..000000000 --- a/argilla-v1/tests/unit/feedback/integrations/huggingface/test_dataset.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict - -import pytest -from argilla_v1 import SuggestionSchema -from argilla_v1.client.feedback.dataset.local.dataset import FeedbackDataset -from argilla_v1.client.feedback.integrations.huggingface.dataset import HuggingFaceDatasetMixin -from argilla_v1.client.feedback.schemas.fields import TextField -from argilla_v1.client.feedback.schemas.questions import MultiLabelQuestion, TextQuestion -from argilla_v1.client.feedback.schemas.records import FeedbackRecord - - -class TestSuiteHuggingFaceDatasetMixin: - @pytest.mark.parametrize( - "record, hf_record", - [ - (FeedbackRecord(fields={"required-field": "value"}), {"required-field": "value", "optional-field": None}), - ( - FeedbackRecord(fields={"required-field": "value", "optional-field": "value"}), - {"required-field": "value", "optional-field": "value"}, - ), - ], - ) - def test__huggingface_format(self, record: FeedbackRecord, hf_record: Dict[str, Any]) -> None: - dataset = FeedbackDataset( - fields=[TextField(name="required-field", required=True), TextField(name="optional-field", required=False)], - questions=[TextQuestion(name="question", required=True)], - ) - dataset.add_records([record]) - - hf_dataset = HuggingFaceDatasetMixin._huggingface_format(dataset=dataset) - assert all(field.name in hf_dataset.features for field in dataset.fields) - assert hf_record == { - key: value for key, value in hf_dataset[0].items() if key in [field.name 
for field in dataset.fields] - } - - def test_format_with_multi_score(self): - dataset = FeedbackDataset( - fields=[TextField(name="text")], - questions=[MultiLabelQuestion(name="topics", labels=["politics", "sports", "economy"])], - ) - dataset.add_records( - [ - FeedbackRecord( - fields={"text": "text"}, - suggestions=[ - SuggestionSchema(value=["politics", "sports"], question_name="topics", score=[0.5, 0.5]) - ], - ) - ] - ) - - hf_dataset = HuggingFaceDatasetMixin._huggingface_format(dataset=dataset) - assert hf_dataset[0]["topics-suggestion-metadata"]["score"] == [0.5, 0.5] diff --git a/argilla-v1/tests/unit/feedback/integrations/huggingface/test_model_card.py b/argilla-v1/tests/unit/feedback/integrations/huggingface/test_model_card.py deleted file mode 100644 index 160f2b682..000000000 --- a/argilla-v1/tests/unit/feedback/integrations/huggingface/test_model_card.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dataclasses import dataclass -from typing import Any, Dict - -import pytest -from argilla_v1.client.feedback.integrations.huggingface.model_card.model_card import ( - _prepare_dict_for_comparison, - _updated_arguments, -) -from argilla_v1.training.utils import get_default_args -from transformers import TrainingArguments - -default_transformer_args = get_default_args(TrainingArguments.__init__) -default_transformer_args_1 = default_transformer_args.copy() -default_transformer_args_1.update({"output_dir": None, "warmup_steps": 100}) -default_transformer_args_2 = default_transformer_args.copy() -default_transformer_args_2.update({"output_dir": {"nested_name": "test"}}) -default_transformer_args_3 = default_transformer_args.copy() -default_transformer_args_3.update({"output_dir": [1.2, 3, "value"]}) - - -@dataclass -class Dummy: - # Test a random class, it could be a loss function passed as a callable, or an instance - # of one for example. - pass - - -default_transformer_args_4 = default_transformer_args.copy() -default_transformer_args_4.update({"output_dir": Dummy, "other": Dummy()}) - - -@pytest.mark.parametrize( - "current_kwargs, new_kwargs", - ( - (default_transformer_args_1, {"warmup_steps": 100}), - (default_transformer_args_2, {"output_dir": {"nested_name": "test"}}), - (default_transformer_args_3, {"output_dir": [1.2, 3, "value"]}), - (default_transformer_args_4, {"output_dir": Dummy, "other": Dummy()}), - ), -) -def test_updated_kwargs(current_kwargs: Dict[str, Any], new_kwargs: Dict[str, Any]): - # Using only the Transformer's TrainingArguments as an example, no need to check if the arguments are correct - - new_arguments = _updated_arguments(default_transformer_args, current_kwargs) - assert set(_prepare_dict_for_comparison(new_arguments).items()) == set( - _prepare_dict_for_comparison(new_kwargs).items() - ) diff --git a/argilla-v1/tests/unit/feedback/integrations/test_sentencetransformers.py 
b/argilla-v1/tests/unit/feedback/integrations/test_sentencetransformers.py deleted file mode 100644 index 3c78f22a0..000000000 --- a/argilla-v1/tests/unit/feedback/integrations/test_sentencetransformers.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List -from unittest.mock import MagicMock - -import numpy as np -import pytest -from argilla_v1.client.feedback.dataset.local.dataset import FeedbackDataset -from argilla_v1.client.feedback.integrations.sentencetransformers import SentenceTransformersExtractor -from argilla_v1.client.feedback.schemas.fields import TextField -from argilla_v1.client.feedback.schemas.questions import TextQuestion -from argilla_v1.client.feedback.schemas.records import FeedbackRecord -from argilla_v1.client.feedback.schemas.vector_settings import VectorSettings -from sentence_transformers import SentenceTransformer - - -@pytest.fixture(scope="function") -def records() -> List[FeedbackRecord]: - return [ - FeedbackRecord(fields={"field_1": "This is a test", "field_2": "This is a test"}), - FeedbackRecord( - fields={ - "field_1": "This is a test", - } - ), - FeedbackRecord( - fields={"field_1": "This is a test", "field_2": "This is a test"}, - ), - ] - - -@pytest.fixture(scope="function") -def dataset() -> FeedbackDataset: - ds = FeedbackDataset( - fields=[ - TextField(name="field_1"), - TextField(name="field_2", required=False), - ], - 
questions=[ - TextQuestion(name="question_1"), - ], - ) - return ds - - -@pytest.fixture(scope="session") -def st_extractor() -> SentenceTransformersExtractor: - model = SentenceTransformer("TaylorAI/bge-micro-v2") - model.get_sentence_embedding_dimension = MagicMock(return_value=1) - model.encode = MagicMock(return_value=np.array([1])) - return SentenceTransformersExtractor(model=model) - - -@pytest.mark.usefixtures("st_extractor", "dataset", "records") -def test_st_extractor( - st_extractor: SentenceTransformersExtractor, dataset: FeedbackDataset, records: List[FeedbackRecord] -): - dataset.add_records(records) - st_extractor = SentenceTransformersExtractor() - assert isinstance(st_extractor.model, SentenceTransformer) - assert st_extractor.embedding_dim == st_extractor.model.get_sentence_embedding_dimension() - assert st_extractor.show_progress - new_dataset = FeedbackDataset( - fields=dataset.fields, - questions=dataset.questions, - vectors_settings=[VectorSettings(name="field_1", dimensions=st_extractor.embedding_dim)], - ) - st_extractor._create_vector_settings = MagicMock(return_value=new_dataset) - new_records = [] - for record in records: - new_records.append(FeedbackRecord(fields=record.fields, vectors={"field_1": np.array([1]).tolist()})) - st_extractor._encode_single_field = MagicMock(return_value=records) - st_extractor.update_records(records, fields=["field_1"]) - st_extractor._encode_single_field.assert_called_once_with(records, "field_1", False) - st_extractor.update_dataset(dataset, update_records=True) - st_extractor._encode_single_field.call_count == len(dataset.fields) + 1 - st_extractor._create_vector_settings.call_count == len(dataset.fields) - - -@pytest.mark.fixtures("st_extractor", "dataset") -def test_create_vector_settings(st_extractor: SentenceTransformersExtractor, dataset: FeedbackDataset): - dataset = st_extractor._create_vector_settings(dataset, fields=["field_1"]) - assert dataset.vectors_settings == [VectorSettings(name="field_1", 
dimensions=st_extractor.embedding_dim)] - - -@pytest.mark.fixtures("st_extractor", "records") -def test_encode_single_field(st_extractor: SentenceTransformersExtractor, records: List[FeedbackRecord]): - records = st_extractor._encode_single_field(records=records, field="field_1", overwrite=False) - assert records[0].vectors["field_1"] == np.array([1]) - - -@pytest.mark.fixtures("st_extractor") -def test_update_dataset_with_invalid_fields(st_extractor: SentenceTransformersExtractor): - dataset = FeedbackDataset( - fields=[TextField(name="text")], - questions=[TextQuestion(name="question")], - ) - with pytest.raises(ValueError): - st_extractor.update_dataset(dataset, fields=["my_fake_field"]) diff --git a/argilla-v1/tests/unit/feedback/integrations/test_textdescriptives.py b/argilla-v1/tests/unit/feedback/integrations/test_textdescriptives.py deleted file mode 100644 index bb872207c..000000000 --- a/argilla-v1/tests/unit/feedback/integrations/test_textdescriptives.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from unittest.mock import MagicMock - -import pandas as pd -import pytest -from argilla_v1.client.feedback.dataset import FeedbackDataset -from argilla_v1.client.feedback.integrations.textdescriptives import TextDescriptivesExtractor -from argilla_v1.client.feedback.schemas.fields import TextField -from argilla_v1.client.feedback.schemas.metadata import ( - FloatMetadataProperty, - IntegerMetadataProperty, - TermsMetadataProperty, -) -from argilla_v1.client.feedback.schemas.questions import TextQuestion -from argilla_v1.client.feedback.schemas.records import FeedbackRecord - - -@pytest.fixture -def records(): - return - - -@pytest.fixture(scope="session") -def td_extractor() -> TextDescriptivesExtractor: - return TextDescriptivesExtractor() - - -@pytest.mark.fixtures("td_extractor") -@pytest.mark.parametrize( - "records", - [ - [ - FeedbackRecord( - fields={"required-field": "This is a test.", "optional-field": None}, - ), - FeedbackRecord( - fields={"required-field": "This is another test.", "optional-field": None}, - ), - ], - [ - FeedbackRecord( - fields={"required-field": "This is a test.", "optional-field": "This is also a test."}, - ), - FeedbackRecord( - fields={"required-field": "This is another test.", "optional-field": "This is also another test."}, - ), - ], - [ - FeedbackRecord( - fields={"required-field": "This is a test."}, - metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4}, - ), - FeedbackRecord( - fields={"required-field": "This is another test."}, - metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4}, - ), - ], - ], -) -def test_extract_metrics_for_single_field(records, td_extractor: TextDescriptivesExtractor) -> None: - field_metrics = td_extractor._extract_metrics_for_single_field(records, "required-field") - assert field_metrics["required-field_n_tokens"].values[0] == 4 - assert len(field_metrics) == len(records) # Assert the number of rows in the DataFrame - assert isinstance(field_metrics, pd.DataFrame) # Assert the data type 
of the DataFrame - assert "required-field_n_tokens" in field_metrics.columns # Assert the presence of the column - assert field_metrics["required-field_n_tokens"].values[0] == 4 # Assert the value of the column - assert "required-field" not in field_metrics.columns # Assert that text column has been dropped - assert not field_metrics.isnull().values.any() # Assert no columns with NaN values - assert "optional-field" not in field_metrics.columns - - -@pytest.mark.fixtures("td_extractor") -def test_extract_metrics_for_single_field_empty_field(td_extractor: TextDescriptivesExtractor) -> None: - records = [ - FeedbackRecord( - fields={"required-field": "This is a test.", "optional-field": None}, - ), - FeedbackRecord( - fields={"required-field": "This is another test.", "optional-field": None}, - ), - ] - field_metrics = td_extractor._extract_metrics_for_single_field(records, "optional-field") - assert field_metrics is None - - -@pytest.mark.fixtures("td_extractor") -@pytest.mark.parametrize( - "records", - [ - [ - FeedbackRecord( - fields={"required-field": "This is a test.", "optional-field": None}, - ), - FeedbackRecord( - fields={"required-field": "This is another test.", "optional-field": None}, - ), - ], - [ - FeedbackRecord( - fields={"required-field": "This is a test.", "optional-field": "This is also a test."}, - ), - FeedbackRecord( - fields={"required-field": "This is another test.", "optional-field": "This is also another test."}, - ), - ], - [ - FeedbackRecord( - fields={"required-field": "This is a test."}, - metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4}, - ), - FeedbackRecord( - fields={"required-field": "This is another test."}, - metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4}, - ), - ], - ], -) -def test_extract_metrics_for_all_fields(records, td_extractor: TextDescriptivesExtractor) -> None: - expected_fields = [key for record in records for key, value in record.fields.items() if value is not None] - field_metrics = 
td_extractor._extract_metrics_for_all_fields(records, overwrite=False, fields=expected_fields) - assert field_metrics["required-field_n_tokens"].values[0] == 4 - assert all(any(field == col or field + "_" in col for col in field_metrics.columns) for field in expected_fields) - - -@pytest.mark.fixtures("td_extractor") -def test_cast_to_python_types(td_extractor: TextDescriptivesExtractor) -> None: - df = pd.DataFrame( - { - "col_int": [1, 2, 3], - "col_bool": [True, False, True], - "col_float": [1.234, 2.345, 3.456], - } - ) - df_result = td_extractor._cast_to_python_types(df) - assert df_result["col_int"].dtype == "int32" or df_result["col_int"].dtype == "int64" - assert df_result["col_bool"].dtype == "object" - assert df_result["col_float"].dtype == "float64" or df_result["col_float"].dtype == "float32" - assert df_result["col_float"].values[0] == 1.23 - assert isinstance(df_result, pd.DataFrame) - - -@pytest.mark.fixtures("td_extractor") -def test_clean_column_name(td_extractor: TextDescriptivesExtractor) -> None: - assert td_extractor._clean_column_name("Test_Col") == "test_col" - assert td_extractor._clean_column_name("test col") == "test_col" - assert td_extractor._clean_column_name("Test-Col") == "test_col" - assert td_extractor._clean_column_name("Test.Col") == "test_col" - - -@pytest.mark.fixtures("td_extractor") -@pytest.mark.parametrize( - "column_name, expected_prop_type, expected_title, expected_visible, expected_type, expected_values", - [ - ("col_int", IntegerMetadataProperty, "Col Int", True, "integer", None), - ("col_bool", TermsMetadataProperty, "Col Bool", True, "terms", ["True", "False"]), - ("col_float", FloatMetadataProperty, "Col Float", True, "float", None), - ("col_obj", TermsMetadataProperty, "Col Obj", True, "terms", ["value_1", "value_2", "value_3"]), - ], -) -def test_create_metadata_properties( - column_name, - expected_prop_type, - expected_title, - expected_visible, - expected_type, - expected_values, - td_extractor: 
TextDescriptivesExtractor, -) -> None: - df = pd.DataFrame( - { - "col_int": pd.Series([1, 2, 3], dtype="int32"), - "col_bool": pd.Series([True, False, True], dtype="bool"), - "col_float": pd.Series([1.234, 2.345, 3.456], dtype="float64"), - "col_obj": pd.Series(["value_1", "value_2", "value_3"], dtype="object"), - } - ) - properties = td_extractor._create_metadata_properties(df) - prop = next((prop for prop in properties if prop.name == column_name), None) - assert isinstance(prop, expected_prop_type) - assert prop.name == column_name - assert prop.title == expected_title - assert prop.visible_for_annotators == expected_visible - assert prop.type == expected_type - if isinstance(prop, TermsMetadataProperty): - assert prop.values == expected_values - - -@pytest.mark.fixtures("td_extractor") -@pytest.mark.parametrize( - "records", - [ - [ - FeedbackRecord( - fields={"required-field": "This is a test.", "optional-field": None}, - ), - FeedbackRecord( - fields={"required-field": "This is another test.", "optional-field": None}, - ), - ], - [ - FeedbackRecord( - fields={"required-field": "This is a test.", "optional-field": "This is also a test."}, - ), - FeedbackRecord( - fields={"required-field": "This is another test.", "optional-field": "This is also another test."}, - ), - ], - [ - FeedbackRecord( - fields={"required-field": "This is a test."}, - metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4}, - ), - FeedbackRecord( - fields={"required-field": "This is another test."}, - metadata={"text_n_tokens": 5, "text_n_unique_tokens": 4}, - ), - ], - ], -) -def test_update_records_metrics_extracted(records, td_extractor: TextDescriptivesExtractor) -> None: - extracted_metrics = pd.DataFrame({"text_n_tokens": [4, 5]}) - td_extractor._extract_metrics_for_all_fields = MagicMock(return_value=extracted_metrics) - td_extractor._cast_to_python_types = MagicMock(return_value=extracted_metrics) - td_extractor._clean_column_name = MagicMock(side_effect=lambda col: col) - 
td_extractor._add_text_descriptives_to_metadata = MagicMock(return_value=records) - updated_records = td_extractor.update_records(records) - td_extractor._extract_metrics_for_all_fields.assert_called_once() - td_extractor._cast_to_python_types.assert_called_once_with(extracted_metrics) - td_extractor._clean_column_name.assert_called_with("text_n_tokens") - td_extractor._add_text_descriptives_to_metadata.assert_called_once_with(records, extracted_metrics) - assert updated_records == records - - -@pytest.mark.fixtures("td_extractor") -def test_update_records_no_metrics_extracted(td_extractor: TextDescriptivesExtractor): - records = [ - FeedbackRecord(fields={"text": "This is a test."}), - FeedbackRecord(fields={"text": "This is another test."}), - ] - td_extractor._extract_metrics_for_all_fields = MagicMock(return_value=pd.DataFrame()) - updated_records = td_extractor.update_records(records) - assert updated_records == records - - -@pytest.mark.fixtures("td_extractor") -def test_update_feedback_dataset(td_extractor: TextDescriptivesExtractor): - dataset = FeedbackDataset( - fields=[TextField(name="text")], - questions=[TextQuestion(name="question")], - ) - records = [ - FeedbackRecord(fields={"text": "This is a test."}), - FeedbackRecord(fields={"text": "This is another test."}), - ] - dataset.add_records(records) - - extracted_metrics = pd.DataFrame({"text_n_tokens": [4, 5]}) - td_extractor._extract_metrics_for_all_fields = MagicMock(return_value=extracted_metrics) - td_extractor._cast_to_python_types = MagicMock(return_value=extracted_metrics) - td_extractor._clean_column_name = MagicMock(side_effect=lambda col: col) - td_extractor._create_metadata_properties = MagicMock(return_value=[IntegerMetadataProperty(name="text_n_tokens")]) - - updated_dataset = td_extractor.update_dataset(dataset, update_records=False) - - td_extractor._extract_metrics_for_all_fields.call_count == 1 - - updated_dataset = td_extractor.update_dataset(dataset, update_records=True) - - 
td_extractor._extract_metrics_for_all_fields.call_count == 3 - td_extractor._cast_to_python_types.assert_called_with(extracted_metrics) - td_extractor._clean_column_name.assert_called_with("text_n_tokens") - td_extractor._create_metadata_properties.assert_called_with(extracted_metrics) - assert updated_dataset == dataset - assert isinstance(updated_dataset, FeedbackDataset) - assert updated_dataset.metadata_properties == [ - IntegerMetadataProperty( - name="text_n_tokens", title="text_n_tokens", visible_for_annotators=True, type="integer", min=None, max=None - ) - ] - - -@pytest.mark.fixtures("td_extractor") -def test_update_dataset_with_invalid_fields(td_extractor: TextDescriptivesExtractor): - dataset = FeedbackDataset( - fields=[TextField(name="text")], - questions=[TextQuestion(name="question")], - ) - with pytest.raises(ValueError): - td_extractor.update_dataset(dataset, fields=["my_fake_field"]) diff --git a/argilla-v1/tests/unit/feedback/schemas/__init__.py b/argilla-v1/tests/unit/feedback/schemas/__init__.py deleted file mode 100644 index 55be41799..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/argilla-v1/tests/unit/feedback/schemas/remote/__init__.py b/argilla-v1/tests/unit/feedback/schemas/remote/__init__.py deleted file mode 100644 index 55be41799..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/remote/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/argilla-v1/tests/unit/feedback/schemas/remote/test_fields.py b/argilla-v1/tests/unit/feedback/schemas/remote/test_fields.py deleted file mode 100644 index 480ccbc4f..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/remote/test_fields.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from datetime import datetime -from typing import Any, Dict -from uuid import uuid4 - -import pytest -from argilla_v1.client.feedback.schemas.enums import FieldTypes -from argilla_v1.client.feedback.schemas.fields import TextField -from argilla_v1.client.feedback.schemas.remote.fields import RemoteTextField -from argilla_v1.client.sdk.v1.datasets.models import FeedbackFieldModel - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a"}, - {"name": "a", "title": "A", "required": True, "settings": {"type": "text", "use_markdown": False, "use_table": False}}, - ), - ( - {"name": "a", "title": "b"}, - {"name": "a", "title": "b", "required": True, "settings": {"type": "text", "use_markdown": False, "use_table": False}}, - ), - ( - {"name": "a", "title": "b", "required": False}, - {"name": "a", "title": "b", "required": False, "settings": {"type": "text", "use_markdown": False, "use_table": False}}, - ), - ( - {"name": "a", "title": "b", "required": False, "use_markdown": True}, - {"name": "a", "title": "b", "required": False, "settings": {"type": "text", "use_markdown": True, "use_table": False}}, - ), - ( - {"name": "a", "title": "b", "required": False, "use_table": True}, - {"name": "a", "title": "b", "required": False, "settings": {"type": "text", "use_markdown": False, "use_table": True}}, - ), - ], -) -def test_remote_text_field(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - text_field = RemoteTextField(**schema_kwargs) - assert text_field.type == FieldTypes.text - assert text_field.server_settings == server_payload["settings"] - assert text_field.to_server_payload() == server_payload - - local_text_field = text_field.to_local() - assert isinstance(local_text_field, TextField) - assert local_text_field.type == FieldTypes.text - assert local_text_field.server_settings == server_payload["settings"] - assert local_text_field.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "payload", - 
[ - FeedbackFieldModel( - id=uuid4(), - name="a", - title="A", - required=True, - settings={"type": "text", "use_markdown": False, "use_table": False}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackFieldModel( - id=uuid4(), - name="b", - title="B", - required=False, - settings={"type": "text", "use_markdown": True, "use_table": False}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackFieldModel( - id=uuid4(), - name="b", - title="B", - required=False, - settings={"type": "text", "use_markdown": False, "use_table": True}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], -) -def test_remote_text_field_from_api(payload: FeedbackFieldModel) -> None: - text_field = RemoteTextField.from_api(payload) - assert text_field.type == FieldTypes.text - assert text_field.server_settings == payload.settings - assert text_field.to_server_payload() == payload.dict(exclude={"id", "inserted_at", "updated_at"}) diff --git a/argilla-v1/tests/unit/feedback/schemas/remote/test_metadata.py b/argilla-v1/tests/unit/feedback/schemas/remote/test_metadata.py deleted file mode 100644 index e30f7bcb5..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/remote/test_metadata.py +++ /dev/null @@ -1,346 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from datetime import datetime -from typing import Any, Dict -from uuid import uuid4 - -import pytest -from argilla_v1.client.feedback.schemas.enums import MetadataPropertyTypes -from argilla_v1.client.feedback.schemas.metadata import ( - FloatMetadataProperty, - IntegerMetadataProperty, - TermsMetadataProperty, -) -from argilla_v1.client.feedback.schemas.remote.metadata import ( - RemoteFloatMetadataProperty, - RemoteIntegerMetadataProperty, - RemoteTermsMetadataProperty, -) -from argilla_v1.client.sdk.v1.datasets.models import FeedbackMetadataPropertyModel - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "terms-metadata"}, - { - "name": "terms-metadata", - "title": "terms-metadata", - "visible_for_annotators": True, - "settings": {"type": "terms"}, - }, - ), - ( - {"name": "terms-metadata", "title": "alt-title"}, - { - "name": "terms-metadata", - "title": "alt-title", - "visible_for_annotators": True, - "settings": {"type": "terms"}, - }, - ), - ( - {"name": "terms-metadata", "visible_for_annotators": False}, - { - "name": "terms-metadata", - "title": "terms-metadata", - "visible_for_annotators": False, - "settings": {"type": "terms"}, - }, - ), - ( - {"name": "terms-metadata", "values": ["a"]}, - { - "name": "terms-metadata", - "title": "terms-metadata", - "visible_for_annotators": True, - "settings": {"type": "terms", "values": ["a"]}, - }, - ), - ( - {"name": "terms-metadata", "values": ["a", "b", "c"]}, - { - "name": "terms-metadata", - "title": "terms-metadata", - "visible_for_annotators": True, - "settings": {"type": "terms", "values": ["a", "b", "c"]}, - }, - ), - ], -) -def test_remote_terms_metadata_property(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - text_field = RemoteTermsMetadataProperty(**schema_kwargs) - assert text_field.type == MetadataPropertyTypes.terms - assert text_field.server_settings == server_payload["settings"] - assert text_field.to_server_payload() == server_payload - 
- local_text_field = text_field.to_local() - assert isinstance(local_text_field, TermsMetadataProperty) - assert local_text_field.type == MetadataPropertyTypes.terms - assert local_text_field.server_settings == server_payload["settings"] - assert local_text_field.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "payload", - [ - FeedbackMetadataPropertyModel( - id=uuid4(), - name="terms-metadata", - title="alt-title", - visible_for_annotators=True, - settings={"type": "terms"}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackMetadataPropertyModel( - id=uuid4(), - name="terms-metadata", - title="terms-metadata", - visible_for_annotators=False, - settings={"type": "terms", "values": ["a", "b", "c"]}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], -) -def test_remote_terms_metadata_property_from_api(payload: FeedbackMetadataPropertyModel) -> None: - text_field = RemoteTermsMetadataProperty.from_api(payload) - assert text_field.type == MetadataPropertyTypes.terms - assert text_field.server_settings == payload.settings - assert text_field.to_server_payload() == payload.dict(exclude={"id", "inserted_at", "updated_at"}) - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "int-metadata"}, - { - "name": "int-metadata", - "title": "int-metadata", - "visible_for_annotators": True, - "settings": {"type": "integer"}, - }, - ), - ( - {"name": "int-metadata", "title": "alt-title"}, - { - "name": "int-metadata", - "title": "alt-title", - "visible_for_annotators": True, - "settings": {"type": "integer"}, - }, - ), - ( - {"name": "int-metadata", "visible_for_annotators": False}, - { - "name": "int-metadata", - "title": "int-metadata", - "visible_for_annotators": False, - "settings": {"type": "integer"}, - }, - ), - ( - {"name": "int-metadata", "min": 0}, - { - "name": "int-metadata", - "title": "int-metadata", - "visible_for_annotators": True, - "settings": {"type": "integer", 
"min": 0}, - }, - ), - ( - {"name": "int-metadata", "max": 10}, - { - "name": "int-metadata", - "title": "int-metadata", - "visible_for_annotators": True, - "settings": {"type": "integer", "max": 10}, - }, - ), - ( - {"name": "int-metadata", "min": 0, "max": 10}, - { - "name": "int-metadata", - "title": "int-metadata", - "visible_for_annotators": True, - "settings": {"type": "integer", "min": 0, "max": 10}, - }, - ), - ], -) -def test_remote_integer_metadata_property(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - text_field = RemoteIntegerMetadataProperty(**schema_kwargs) - assert text_field.type == MetadataPropertyTypes.integer - assert text_field.server_settings == server_payload["settings"] - assert text_field.to_server_payload() == server_payload - - local_text_field = text_field.to_local() - assert isinstance(local_text_field, IntegerMetadataProperty) - assert local_text_field.type == MetadataPropertyTypes.integer - assert local_text_field.server_settings == server_payload["settings"] - assert local_text_field.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "payload", - [ - FeedbackMetadataPropertyModel( - id=uuid4(), - name="int-metadata", - title="alt-title", - visible_for_annotators=True, - settings={"type": "integer"}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackMetadataPropertyModel( - id=uuid4(), - name="int-metadata", - title="int-metadata", - visible_for_annotators=True, - settings={"type": "integer", "min": 0}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackMetadataPropertyModel( - id=uuid4(), - name="int-metadata", - title="int-metadata", - visible_for_annotators=False, - settings={"type": "integer", "min": 0, "max": 10}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], -) -def test_remote_integer_metadata_property_from_api(payload: FeedbackMetadataPropertyModel) -> None: - text_field = 
RemoteIntegerMetadataProperty.from_api(payload) - assert text_field.type == MetadataPropertyTypes.integer - assert text_field.server_settings == payload.settings - assert text_field.to_server_payload() == payload.dict(exclude={"id", "inserted_at", "updated_at"}) - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "float-metadata"}, - { - "name": "float-metadata", - "title": "float-metadata", - "visible_for_annotators": True, - "settings": {"type": "float"}, - }, - ), - ( - {"name": "float-metadata", "title": "alt-title"}, - { - "name": "float-metadata", - "title": "alt-title", - "visible_for_annotators": True, - "settings": {"type": "float"}, - }, - ), - ( - {"name": "float-metadata", "visible_for_annotators": False}, - { - "name": "float-metadata", - "title": "float-metadata", - "visible_for_annotators": False, - "settings": {"type": "float"}, - }, - ), - ( - {"name": "float-metadata", "min": 0.0}, - { - "name": "float-metadata", - "title": "float-metadata", - "visible_for_annotators": True, - "settings": {"type": "float", "min": 0.0}, - }, - ), - ( - {"name": "float-metadata", "max": 10.0}, - { - "name": "float-metadata", - "title": "float-metadata", - "visible_for_annotators": True, - "settings": {"type": "float", "max": 10.0}, - }, - ), - ( - {"name": "float-metadata", "min": 0.0, "max": 10.0}, - { - "name": "float-metadata", - "title": "float-metadata", - "visible_for_annotators": True, - "settings": {"type": "float", "min": 0.0, "max": 10.0}, - }, - ), - ], -) -def test_remote_float_metadata_property(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - text_field = RemoteFloatMetadataProperty(**schema_kwargs) - assert text_field.type == MetadataPropertyTypes.float - assert text_field.server_settings == server_payload["settings"] - assert text_field.to_server_payload() == server_payload - - local_text_field = text_field.to_local() - assert isinstance(local_text_field, FloatMetadataProperty) - assert 
local_text_field.type == MetadataPropertyTypes.float - assert local_text_field.server_settings == server_payload["settings"] - assert local_text_field.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "payload", - [ - FeedbackMetadataPropertyModel( - id=uuid4(), - name="float-metadata", - title="alt-title", - visible_for_annotators=True, - settings={"type": "float"}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackMetadataPropertyModel( - id=uuid4(), - name="float-metadata", - title="float-metadata", - visible_for_annotators=True, - settings={"type": "float", "min": 0.0}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackMetadataPropertyModel( - id=uuid4(), - name="float-metadata", - title="float-metadata", - visible_for_annotators=False, - settings={"type": "float", "min": 0.0, "max": 10.0}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], -) -def test_remote_float_metadata_property_from_api(payload: FeedbackMetadataPropertyModel) -> None: - text_field = RemoteFloatMetadataProperty.from_api(payload) - assert text_field.type == MetadataPropertyTypes.float - assert text_field.server_settings == payload.settings - assert text_field.to_server_payload() == payload.dict(exclude={"id", "inserted_at", "updated_at"}) diff --git a/argilla-v1/tests/unit/feedback/schemas/remote/test_questions.py b/argilla-v1/tests/unit/feedback/schemas/remote/test_questions.py deleted file mode 100644 index 8b20b4922..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/remote/test_questions.py +++ /dev/null @@ -1,535 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime -from typing import Any, Dict -from uuid import uuid4 - -import pytest -from argilla_v1.client.feedback.schemas.enums import LabelsOrder, QuestionTypes -from argilla_v1.client.feedback.schemas.questions import ( - LabelQuestion, - MultiLabelQuestion, - RankingQuestion, - RatingQuestion, - TextQuestion, -) -from argilla_v1.client.feedback.schemas.remote.questions import ( - RemoteLabelQuestion, - RemoteMultiLabelQuestion, - RemoteRankingQuestion, - RemoteRatingQuestion, - RemoteSpanQuestion, - RemoteTextQuestion, -) -from argilla_v1.client.sdk.v1.datasets.models import FeedbackQuestionModel - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a", "required": True, "use_markdown": True}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": {"type": "text", "use_markdown": True, "use_table": False}, - }, - ), - ( - {"name": "a", "title": "B", "description": "b", "required": False, "use_markdown": False}, - { - "name": "a", - "title": "B", - "description": "b", - "required": False, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - }, - ), - ( - {"name": "a", "title": "B", "description": "b", "required": False, "use_table": True}, - { - "name": "a", - "title": "B", - "description": "b", - "required": False, - "settings": {"type": "text", "use_markdown": False, "use_table": True}, - }, - ), - ], -) -def test_remote_text_question(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - text_question = 
RemoteTextQuestion(**schema_kwargs) - assert text_question.type == QuestionTypes.text - assert text_question.server_settings == server_payload["settings"] - assert text_question.to_server_payload() == server_payload - - local_text_question = text_question.to_local() - assert isinstance(local_text_question, TextQuestion) - assert local_text_question.type == QuestionTypes.text - assert local_text_question.server_settings == server_payload["settings"] - assert local_text_question.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "payload", - [ - FeedbackQuestionModel( - id=uuid4(), - name="a", - title="A", - description="Description", - required=True, - settings={"type": "text", "use_markdown": False, "use_table": False}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackQuestionModel( - id=uuid4(), - name="b", - title="B", - description="Description", - required=False, - settings={"type": "text", "use_markdown": True, "use_table": False}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackQuestionModel( - id=uuid4(), - name="b", - title="B", - description="Description", - required=False, - settings={"type": "text", "use_markdown": False, "use_table": True}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], -) -def test_remote_text_question_from_api(payload: FeedbackQuestionModel) -> None: - text_question = RemoteTextQuestion.from_api(payload) - assert text_question.type == QuestionTypes.text - assert text_question.server_settings == payload.settings - assert text_question.to_server_payload() == payload.dict(exclude={"id", "inserted_at", "updated_at"}) - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a", "values": [1, 2, 3]}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": {"type": "rating", "options": [{"value": 1}, {"value": 2}, {"value": 3}]}, - }, - ), - ( - {"name": "a", "title": "B", 
"description": "b", "required": False, "values": [1, 2, 3]}, - { - "name": "a", - "title": "B", - "description": "b", - "required": False, - "settings": {"type": "rating", "options": [{"value": 1}, {"value": 2}, {"value": 3}]}, - }, - ), - ], -) -def test_remote_rating_question(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - rating_question = RemoteRatingQuestion(**schema_kwargs) - assert rating_question.type == QuestionTypes.rating - assert rating_question.server_settings == server_payload["settings"] - assert rating_question.to_server_payload() == server_payload - - local_rating_question = rating_question.to_local() - assert isinstance(local_rating_question, RatingQuestion) - assert local_rating_question.type == QuestionTypes.rating - assert local_rating_question.server_settings == server_payload["settings"] - assert local_rating_question.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "payload", - [ - FeedbackQuestionModel( - id=uuid4(), - name="a", - title="A", - description="Description", - required=True, - settings={"type": "rating", "options": [{"value": 1}, {"value": 2}, {"value": 3}]}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackQuestionModel( - id=uuid4(), - name="b", - title="B", - description="Description", - required=False, - settings={"type": "rating", "options": [{"value": 1}, {"value": 2}, {"value": 3}]}, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], -) -def test_remote_rating_question_from_api(payload: FeedbackQuestionModel) -> None: - rating_question = RemoteRatingQuestion.from_api(payload) - assert rating_question.type == QuestionTypes.rating - assert rating_question.server_settings == payload.settings - assert rating_question.to_server_payload() == payload.dict(exclude={"id", "inserted_at", "updated_at"}) - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a", "labels": ["a", "b", "c"]}, - { - "name": "a", - 
"title": "A", - "description": None, - "required": True, - "settings": { - "type": "label_selection", - "options": [{"text": "a", "value": "a"}, {"text": "b", "value": "b"}, {"text": "c", "value": "c"}], - "visible_options": None, - }, - }, - ), - ( - { - "name": "a", - "title": "B", - "description": "b", - "required": False, - "labels": {"a": "A", "b": "B", "c": "C"}, - "visible_labels": 3, - }, - { - "name": "a", - "title": "B", - "description": "b", - "required": False, - "settings": { - "type": "label_selection", - "options": [{"text": "A", "value": "a"}, {"text": "B", "value": "b"}, {"text": "C", "value": "c"}], - "visible_options": 3, - }, - }, - ), - ], -) -def test_remote_label_question(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - label_question = RemoteLabelQuestion(**schema_kwargs) - assert label_question.type == QuestionTypes.label_selection - assert label_question.server_settings == server_payload["settings"] - assert label_question.to_server_payload() == server_payload - - local_label_question = label_question.to_local() - assert isinstance(local_label_question, LabelQuestion) - assert local_label_question.type == QuestionTypes.label_selection - assert local_label_question.server_settings == server_payload["settings"] - assert local_label_question.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "payload", - [ - FeedbackQuestionModel( - id=uuid4(), - name="a", - title="A", - required=True, - description="Description", - settings={ - "type": "label_selection", - "options": [{"text": "a", "value": "a"}, {"text": "b", "value": "b"}, {"text": "c", "value": "c"}], - "visible_options": None, - }, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackQuestionModel( - id=uuid4(), - name="b", - title="B", - description="Description", - required=False, - settings={ - "type": "label_selection", - "options": [{"text": "A", "value": "a"}, {"text": "B", "value": "b"}, {"text": "C", "value": 
"c"}], - "visible_options": 3, - }, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], -) -def test_remote_label_question_from_api(payload: FeedbackQuestionModel) -> None: - label_question = RemoteLabelQuestion.from_api(payload) - assert label_question.type == QuestionTypes.label_selection - assert label_question.server_settings == payload.settings - assert label_question.to_server_payload() == payload.dict(exclude={"id", "inserted_at", "updated_at"}) - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a", "labels": ["a", "b", "c"]}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "multi_label_selection", - "options": [{"text": "a", "value": "a"}, {"text": "b", "value": "b"}, {"text": "c", "value": "c"}], - "visible_options": None, - "options_order": LabelsOrder.natural, - }, - }, - ), - ( - { - "name": "a", - "title": "B", - "description": "b", - "required": False, - "labels": {"a": "A", "b": "B", "c": "C"}, - "visible_labels": 3, - "labels_order": LabelsOrder.suggestion, - }, - { - "name": "a", - "title": "B", - "description": "b", - "required": False, - "settings": { - "type": "multi_label_selection", - "options": [{"text": "A", "value": "a"}, {"text": "B", "value": "b"}, {"text": "C", "value": "c"}], - "visible_options": 3, - "options_order": LabelsOrder.suggestion, - }, - }, - ), - ], -) -def test_remote_multi_label_question(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - multi_label_question = RemoteMultiLabelQuestion(**schema_kwargs) - assert multi_label_question.type == QuestionTypes.multi_label_selection - assert multi_label_question.server_settings == server_payload["settings"] - assert multi_label_question.to_server_payload() == server_payload - - local_multi_label_question = multi_label_question.to_local() - assert isinstance(local_multi_label_question, MultiLabelQuestion) - assert local_multi_label_question.type == 
QuestionTypes.multi_label_selection - assert local_multi_label_question.server_settings == server_payload["settings"] - assert local_multi_label_question.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "payload", - [ - FeedbackQuestionModel( - id=uuid4(), - name="a", - title="A", - description="Description", - required=True, - settings={ - "type": "multi_label_selection", - "options": [{"text": "a", "value": "a"}, {"text": "b", "value": "b"}, {"text": "c", "value": "c"}], - "visible_options": None, - "options_order": LabelsOrder.natural, - }, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackQuestionModel( - id=uuid4(), - name="b", - title="B", - description="Description", - required=False, - settings={ - "type": "multi_label_selection", - "options": [{"text": "A", "value": "a"}, {"text": "B", "value": "b"}, {"text": "C", "value": "c"}], - "visible_options": 3, - "options_order": LabelsOrder.suggestion, - }, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], -) -def test_remote_multi_label_question_from_api(payload: FeedbackQuestionModel) -> None: - multi_label_question = RemoteMultiLabelQuestion.from_api(payload) - assert multi_label_question.type == QuestionTypes.multi_label_selection - assert multi_label_question.server_settings == payload.settings - assert multi_label_question.to_server_payload() == payload.dict(exclude={"id", "inserted_at", "updated_at"}) - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a", "values": ["a", "b", "c"]}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "ranking", - "options": [{"text": "a", "value": "a"}, {"text": "b", "value": "b"}, {"text": "c", "value": "c"}], - }, - }, - ), - ( - { - "name": "a", - "title": "B", - "description": "b", - "required": False, - "values": {"a": "A", "b": "B", "c": "C"}, - }, - { - "name": "a", - "title": "B", - "description": "b", - "required": 
False, - "settings": { - "type": "ranking", - "options": [{"text": "A", "value": "a"}, {"text": "B", "value": "b"}, {"text": "C", "value": "c"}], - }, - }, - ), - ], -) -def test_remote_ranking_question(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - ranking_question = RemoteRankingQuestion(**schema_kwargs) - assert ranking_question.type == QuestionTypes.ranking - assert ranking_question.server_settings == server_payload["settings"] - assert ranking_question.to_server_payload() == server_payload - - local_ranking_question = ranking_question.to_local() - assert isinstance(local_ranking_question, RankingQuestion) - assert local_ranking_question.type == QuestionTypes.ranking - assert local_ranking_question.server_settings == server_payload["settings"] - assert local_ranking_question.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "payload", - [ - FeedbackQuestionModel( - id=uuid4(), - name="a", - title="A", - description="Description", - required=True, - settings={ - "type": "ranking", - "options": [{"text": "a", "value": "a"}, {"text": "b", "value": "b"}, {"text": "c", "value": "c"}], - }, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackQuestionModel( - id=uuid4(), - name="b", - title="B", - description="Description", - required=False, - settings={ - "type": "ranking", - "options": [{"text": "A", "value": "a"}, {"text": "B", "value": "b"}, {"text": "C", "value": "c"}], - }, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], -) -def test_remote_ranking_question_from_api(payload: FeedbackQuestionModel) -> None: - ranking_question = RemoteRankingQuestion.from_api(payload) - assert ranking_question.type == QuestionTypes.ranking - assert ranking_question.server_settings == payload.settings - assert ranking_question.to_server_payload() == payload.dict(exclude={"id", "inserted_at", "updated_at"}) - - -def test_span_questions_from_api(): - model = FeedbackQuestionModel( - id=uuid4(), 
- name="question", - title="Question", - required=True, - settings={ - "type": "span", - "field": "field", - "visible_options": None, - "allow_overlapping": False, - "options": [ - {"text": "Span label a", "value": "a", "description": None}, - { - "text": "Span label b", - "value": "b", - "description": None, - }, - ], - }, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ) - question = RemoteSpanQuestion.from_api(model) - - assert question.type == QuestionTypes.span - assert question.server_settings == model.settings - assert question.to_server_payload() == model.dict(exclude={"id", "inserted_at", "updated_at"}) - assert question.to_local().type == QuestionTypes.span - - -def test_span_questions_from_api_with_visible_labels(): - model = FeedbackQuestionModel( - id=uuid4(), - name="question", - title="Question", - required=True, - settings={ - "type": "span", - "field": "field", - "visible_options": 3, - "allow_overlapping": False, - "options": [ - {"text": "Span label a", "value": "a", "description": None}, - {"text": "Span label b", "value": "b", "description": None}, - {"text": "Span label c", "value": "c", "description": None}, - {"text": "Span label d", "value": "d", "description": None}, - ], - }, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ) - question = RemoteSpanQuestion.from_api(model) - - assert question.type == QuestionTypes.span - assert question.server_settings == model.settings - assert question.to_server_payload() == model.dict(exclude={"id", "inserted_at", "updated_at"}) - assert question.to_local().type == QuestionTypes.span diff --git a/argilla-v1/tests/unit/feedback/schemas/remote/test_records.py b/argilla-v1/tests/unit/feedback/schemas/remote/test_records.py deleted file mode 100644 index 9ca24088f..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/remote/test_records.py +++ /dev/null @@ -1,381 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime -from typing import Any, Dict -from uuid import UUID, uuid4 - -import pytest -from argilla_v1.client.feedback.schemas.records import FeedbackRecord, ResponseSchema, SuggestionSchema -from argilla_v1.client.feedback.schemas.remote.records import ( - RemoteFeedbackRecord, - RemoteResponseSchema, - RemoteSuggestionSchema, -) -from argilla_v1.client.sdk.v1.datasets.models import ( - FeedbackItemModel, - FeedbackRankingValueModel, - FeedbackResponseModel, - FeedbackSuggestionModel, - FeedbackValueModel, -) - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - { - "question_id": UUID("00000000-0000-0000-0000-000000000000"), - "question_name": "question-1", - "type": "human", - "score": 0.5, - "value": "a", - "agent": "b", - }, - { - "question_id": "00000000-0000-0000-0000-000000000000", - "type": "human", - "score": 0.5, - "value": "a", - "agent": "b", - }, - ), - ( - { - "question_id": UUID("00000000-0000-0000-0000-000000000000"), - "question_name": "question-1", - "type": "model", - "score": 1.0, - "value": "a", - "agent": "b", - }, - { - "question_id": "00000000-0000-0000-0000-000000000000", - "type": "model", - "score": 1.0, - "value": "a", - "agent": "b", - }, - ), - ], -) -def test_remote_suggestion_schema(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - suggestion = RemoteSuggestionSchema(**schema_kwargs) - assert ( - 
suggestion.to_server_payload(question_name_to_id={schema_kwargs["question_name"]: schema_kwargs["question_id"]}) - == server_payload - ) - - local_suggestion = suggestion.to_local() - assert isinstance(local_suggestion, SuggestionSchema) - assert ( - local_suggestion.to_server_payload( - question_name_to_id={schema_kwargs["question_name"]: schema_kwargs["question_id"]} - ) - == server_payload - ) - - -@pytest.mark.parametrize( - "payload", - [ - FeedbackSuggestionModel( - id=uuid4(), - question_id=str(uuid4()), - type="human", - score=0.5, - value="a", - agent="b", - ), - FeedbackSuggestionModel( - id=uuid4(), - question_id=str(uuid4()), - type="model", - score=1.0, - value="a", - agent="b", - ), - ], -) -def test_remote_suggestion_schema_from_api(payload: FeedbackSuggestionModel) -> None: - suggestion = RemoteSuggestionSchema.from_api(payload, question_id_to_name={UUID(payload.question_id): "question-1"}) - assert suggestion.to_server_payload(question_name_to_id={"question-1": payload.question_id}) == payload.dict( - exclude={"id"} - ) - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - { - "user_id": UUID("00000000-0000-0000-0000-000000000000"), - "values": { - "question-1": {"value": "a"}, - "question-2": {"value": 1}, - "question-3": {"value": ["a", "b"]}, - "question-4": {"value": [{"value": "a", "rank": 1}, {"value": "b", "rank": 2}]}, - }, - "status": "submitted", - "inserted_at": datetime.now(), - "updated_at": datetime.now(), - }, - { - "user_id": UUID("00000000-0000-0000-0000-000000000000"), - "values": { - "question-1": {"value": "a"}, - "question-2": {"value": 1}, - "question-3": {"value": ["a", "b"]}, - "question-4": {"value": [{"value": "a", "rank": 1}, {"value": "b", "rank": 2}]}, - }, - "status": "submitted", - }, - ), - ( - { - "user_id": UUID("00000000-0000-0000-0000-000000000000"), - "values": {"question-1": {"value": "a"}}, - "status": "draft", - "inserted_at": datetime.now(), - "updated_at": datetime.now(), - }, - { - 
"user_id": UUID("00000000-0000-0000-0000-000000000000"), - "values": {"question-1": {"value": "a"}}, - "status": "draft", - }, - ), - ( - { - "user_id": UUID("00000000-0000-0000-0000-000000000000"), - "values": None, - "status": "discarded", - "inserted_at": datetime.now(), - "updated_at": datetime.now(), - }, - { - "user_id": UUID("00000000-0000-0000-0000-000000000000"), - "values": None, - "status": "discarded", - }, - ), - ], -) -def test_remote_response_schema(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - response = RemoteResponseSchema(**schema_kwargs) - assert response.to_server_payload() == server_payload - - local_response = response.to_local() - assert isinstance(local_response, ResponseSchema) - assert local_response.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "payload", - [ - FeedbackResponseModel( - id=uuid4(), - values={ - "question-1": FeedbackValueModel(value="a"), - "question-2": FeedbackValueModel(value=1), - "question-3": FeedbackValueModel(value=["a", "b"]), - "question-4": FeedbackValueModel( - value=[FeedbackRankingValueModel(value="a", rank=1), FeedbackRankingValueModel(value="b", rank=2)] - ), - "question-5": FeedbackValueModel(value=[{"start": 0, "end": 1, "label": "a"}]), - }, - status="submitted", - user_id=uuid4(), - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackResponseModel( - id=uuid4(), - values={"question-1": FeedbackValueModel(value="a")}, - status="draft", - user_id=uuid4(), - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - FeedbackResponseModel( - id=uuid4(), - values={"span-question": FeedbackValueModel(value=[{"start": 0, "end": 1, "label": "a"}])}, - status="discarded", - user_id=uuid4(), - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], -) -def test_remote_response_schema_from_api(payload: FeedbackResponseModel) -> None: - response = RemoteResponseSchema.from_api(payload) - assert response.to_server_payload() == 
payload.dict(exclude={"id", "inserted_at", "updated_at"}) - - local_response = response.to_local() - assert isinstance(local_response, ResponseSchema) - assert local_response.to_server_payload() == payload.dict(exclude={"id", "inserted_at", "updated_at"}) - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - { - "id": UUID("00000000-0000-0000-0000-000000000000"), - "fields": {"text": "This is the first record", "label": "positive", "optional": None}, - "metadata": {"first": True, "nested": {"more": "stuff"}}, - "responses": [ - { - "values": {"question-1": {"value": "This is the first answer"}, "question-2": {"value": 5}}, - "status": "submitted", - "inserted_at": datetime.now(), - "updated_at": datetime.now(), - }, - ], - "suggestions": [ - { - "question_id": UUID("00000000-0000-0000-0000-000000000000"), - "question_name": "question-1", - "type": "model", - "score": 0.9, - "value": "This is the first suggestion", - "agent": "agent-1", - }, - ], - "vectors": { - "vector-1": [1.0, 2.0, 3.0], - "vector-2": [1.0, 2.0, 3.0, 4.0], - }, - "external_id": "entry-1", - }, - { - "fields": {"text": "This is the first record", "label": "positive"}, - "metadata": {"first": True, "nested": {"more": "stuff"}}, - "responses": [ - { - "user_id": None, - "values": {"question-1": {"value": "This is the first answer"}, "question-2": {"value": 5}}, - "status": "submitted", - }, - ], - "suggestions": [ - { - "question_id": "00000000-0000-0000-0000-000000000000", - "type": "model", - "score": 0.9, - "value": "This is the first suggestion", - "agent": "agent-1", - }, - ], - "vectors": { - "vector-1": [1.0, 2.0, 3.0], - "vector-2": [1.0, 2.0, 3.0, 4.0], - }, - "external_id": "entry-1", - }, - ), - ], -) -def test_remote_feedback_record(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - record = RemoteFeedbackRecord( - **schema_kwargs, question_name_to_id={"question-1": UUID("00000000-0000-0000-0000-000000000000")} - ) - assert ( - 
record.to_server_payload(question_name_to_id={"question-1": UUID("00000000-0000-0000-0000-000000000000")}) - == server_payload - ) - - local_record = record.to_local() - assert isinstance(local_record, FeedbackRecord) - assert ( - local_record.to_server_payload(question_name_to_id={"question-1": UUID("00000000-0000-0000-0000-000000000000")}) - == server_payload - ) - - -@pytest.mark.parametrize( - "payload", - [ - FeedbackItemModel( - id=uuid4(), - fields={"text": "This is the first record", "label": "positive"}, - metadata={"first": True, "nested": {"more": "stuff"}}, - external_id="entry-1", - responses=[ - FeedbackResponseModel( - id=uuid4(), - values={ - "question-1": FeedbackValueModel(value="This is the first answer"), - }, - status="submitted", - user_id=uuid4(), - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], - suggestions=[ - FeedbackSuggestionModel( - id=uuid4(), - question_id=str(uuid4()), - type="model", - score=0.9, - value="This is the first suggestion", - agent="agent-1", - ) - ], - vectors={ - "vector-1": [1.0, 2.0, 3.0], - "vector-2": [1.0, 2.0, 3.0, 4.0], - }, - inserted_at=datetime.now(), - updated_at=datetime.now(), - ), - ], -) -def test_remote_feedback_record_schema_from_api(payload: FeedbackItemModel) -> None: - record = RemoteFeedbackRecord.from_api( - payload, question_id_to_name={UUID(payload.suggestions[0].question_id): "question-1"} - ) - # Skipping `suggestions` temporarily as it's now a tuple internally formatted and the type is not preserved - assert record.dict( - exclude={ - "client": ..., - "responses": {"__all__": {"id", "client"}}, - "suggestions": ..., - "inserted_at": ..., - "updated_at": ..., - } - ) == payload.dict( - exclude={ - "responses": {"__all__": {"id"}}, - "suggestions": ..., - "inserted_at": ..., - "updated_at": ..., - } - ) - - local_record = record.to_local() - assert isinstance(local_record, FeedbackRecord) - assert local_record.to_server_payload( - question_name_to_id={"question-1": 
payload.suggestions[0].question_id} - ) == payload.dict( - exclude={ - "id": ..., - "responses": {"__all__": {"id", "inserted_at", "updated_at"}}, - "suggestions": {"__all__": {"id"}}, - "inserted_at": ..., - "updated_at": ..., - } - ) diff --git a/argilla-v1/tests/unit/feedback/schemas/test_fields.py b/argilla-v1/tests/unit/feedback/schemas/test_fields.py deleted file mode 100644 index 0f917ac0c..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/test_fields.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Any, Dict - -import pytest -from argilla_v1.client.feedback.schemas.enums import FieldTypes -from argilla_v1.client.feedback.schemas.fields import TextField - -from tests.pydantic_v1 import ValidationError - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a"}, - {"name": "a", "title": "A", "required": True, "settings": {"type": "text", "use_markdown": False, "use_table": False}}, - ), - ( - {"name": "a", "title": "b"}, - {"name": "a", "title": "b", "required": True, "settings": {"type": "text", "use_markdown": False, "use_table": False}}, - ), - ( - {"name": "a", "title": "b", "required": False}, - {"name": "a", "title": "b", "required": False, "settings": {"type": "text", "use_markdown": False, "use_table": False}}, - ), - ( - {"name": "a", "title": "b", "required": False, "use_markdown": True}, - {"name": "a", "title": "b", "required": False, "settings": {"type": "text", "use_markdown": True, "use_table": False}}, - ), - ( - {"name": "a", "title": "b", "required": False, "use_table": True}, - {"name": "a", "title": "b", "required": False, "settings": {"type": "text", "use_markdown": False, "use_table": True}}, - ), - ], -) -def test_text_field(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - text_field = TextField(**schema_kwargs) - assert text_field.type == FieldTypes.text - assert text_field.server_settings == server_payload["settings"] - assert text_field.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "schema_kwargs, exception_cls, exception_message", - [ - ({"name": "a b"}, ValidationError, "name\n string does not match regex"), - ({}, ValidationError, "name\n field required"), - # The test case below won't match the full regex, as it will assume the type is QuestionType.text instead, only God knows why - ({"name": "a", "type": "other"}, ValidationError, "type\n unexpected value; permitted:"), - ], -) -def test_text_field_errors(schema_kwargs: Dict[str, 
Any], exception_cls: Any, exception_message: str) -> None: - with pytest.raises(exception_cls, match=exception_message): - TextField(**schema_kwargs) diff --git a/argilla-v1/tests/unit/feedback/schemas/test_metadata.py b/argilla-v1/tests/unit/feedback/schemas/test_metadata.py deleted file mode 100644 index 75473893b..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/test_metadata.py +++ /dev/null @@ -1,427 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Any, Dict - -import pytest -from argilla_v1.client.feedback.schemas.enums import MetadataPropertyTypes -from argilla_v1.client.feedback.schemas.metadata import ( - FloatMetadataFilter, - FloatMetadataProperty, - IntegerMetadataFilter, - IntegerMetadataProperty, - TermsMetadataFilter, - TermsMetadataProperty, -) - -from tests.pydantic_v1 import ValidationError, create_model - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload, metadata_filter, metadata_property_to_validate", - [ - ( - {"name": "terms-metadata", "title": "alt-title"}, - { - "name": "terms-metadata", - "title": "alt-title", - "visible_for_annotators": True, - "settings": {"type": "terms"}, - }, - TermsMetadataFilter(name="terms-metadata", values=["a", "b", "c"]), - {"terms-metadata": "a"}, - ), - ( - {"name": "terms-metadata", "values": ["a", "b", "c"]}, - { - "name": "terms-metadata", - "title": "terms-metadata", - "visible_for_annotators": True, - "settings": {"type": "terms", "values": ["a", "b", "c"]}, - }, - TermsMetadataFilter(name="terms-metadata", values=["a", "b", "c"]), - {"terms-metadata": "a"}, - ), - ( - { - "name": "terms-metadata", - "visible_for_annotators": False, - "values": ["a", "b", "c"], - }, - { - "name": "terms-metadata", - "title": "terms-metadata", - "visible_for_annotators": False, - "settings": {"type": "terms", "values": ["a", "b", "c"]}, - }, - TermsMetadataFilter(name="terms-metadata", values=["a", "b", "c"]), - {"terms-metadata": "a"}, - ), - ], -) -def test_terms_metadata_property( - schema_kwargs: Dict[str, Any], - server_payload: Dict[str, Any], - metadata_filter: TermsMetadataFilter, - metadata_property_to_validate: Dict[str, str], -) -> None: - metadata_property = TermsMetadataProperty(**schema_kwargs) - assert metadata_property.type == MetadataPropertyTypes.terms - assert metadata_property.server_settings == server_payload["settings"] - assert metadata_property.to_server_payload() == server_payload - - 
metadata_property._validate_filter(metadata_filter=metadata_filter) # ValidationError not raised - - metadata_field, metadata_validator = metadata_property._pydantic_field_with_validator - MetadataModel = create_model("MetadataModel", **metadata_field, __validators__=metadata_validator) - assert MetadataModel(**metadata_property_to_validate) - - -@pytest.mark.parametrize( - "schema_kwargs, exception_cls, exception_message", - [ - ({"name": "a b"}, ValidationError, "name\n string does not match regex"), - ( - {"name": "terms-metadata-property", "values": []}, - ValidationError, - "1 validation error for TermsMetadataProperty\nvalues\n `TermsMetadataProperty` with name=terms-metadata-property must have at least 1 `values`", - ), - ( - {"name": "terms-metadata-property", "values": ["a", "a"]}, - ValidationError, - "1 validation error for TermsMetadataProperty\nvalues\n `TermsMetadataProperty` with name=terms-metadata-property cannot have repeated `values`", - ), - ( - {"name": "int-metadata-property", "extra-arg": "a"}, - ValidationError, - "1 validation error for TermsMetadataProperty\nextra-arg\n extra fields not permitted \(type=value_error.extra\)", - ), - ], -) -def test_terms_metadata_property_errors( - schema_kwargs: Dict[str, Any], exception_cls: Any, exception_message: str -) -> None: - with pytest.raises(exception_cls, match=exception_message): - TermsMetadataProperty(**schema_kwargs) - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload, metadata_filter, metadata_property_to_validate", - [ - ( - {"name": "int-metadata", "title": "alt-title"}, - { - "name": "int-metadata", - "title": "alt-title", - "visible_for_annotators": True, - "settings": {"type": "integer"}, - }, - IntegerMetadataFilter(name="int-metadata", le=10, ge=5), - {"int-metadata": 7}, - ), - ( - { - "name": "int-metadata", - "visible_for_annotators": False, - "max": 5, - }, - { - "name": "int-metadata", - "title": "int-metadata", - "visible_for_annotators": False, - "settings": 
{"type": "integer", "max": 5}, - }, - IntegerMetadataFilter(name="int-metadata", le=5, ge=0), - {"int-metadata": 3}, - ), - ( - {"name": "int-metadata", "min": 5}, - { - "name": "int-metadata", - "title": "int-metadata", - "visible_for_annotators": True, - "settings": {"type": "integer", "min": 5}, - }, - IntegerMetadataFilter(name="int-metadata", le=10, ge=5), - {"int-metadata": 7}, - ), - ( - {"name": "int-metadata", "min": 5, "max": 10}, - { - "name": "int-metadata", - "title": "int-metadata", - "visible_for_annotators": True, - "settings": {"type": "integer", "min": 5, "max": 10}, - }, - IntegerMetadataFilter(name="int-metadata", le=10, ge=5), - {"int-metadata": 7}, - ), - ], -) -def test_integer_metadata_property( - schema_kwargs: Dict[str, Any], - server_payload: Dict[str, Any], - metadata_filter: IntegerMetadataFilter, - metadata_property_to_validate: Dict[str, str], -) -> None: - metadata_property = IntegerMetadataProperty(**schema_kwargs) - assert metadata_property.type == MetadataPropertyTypes.integer - assert metadata_property.server_settings == server_payload["settings"] - assert metadata_property.to_server_payload() == server_payload - - metadata_property._validate_filter(metadata_filter=metadata_filter) # ValidationError not raised - - metadata_field, metadata_validator = metadata_property._pydantic_field_with_validator - MetadataModel = create_model("MetadataModel", **metadata_field, __validators__=metadata_validator) - assert MetadataModel(**metadata_property_to_validate) - - -@pytest.mark.parametrize( - "schema_kwargs, exception_cls, exception_message", - [ - ({"name": "a b"}, ValidationError, "name\n string does not match regex"), - ( - {"name": "int-metadata-property", "min": 5, "max": 5}, - ValidationError, - "1 validation error for IntegerMetadataProperty\n__root__\n `IntegerMetadataProperty` with name=int-metadata-property cannot have `max` less or equal than `min`", - ), - ( - {"name": "int-metadata-property", "min": 6, "max": 6}, - 
ValidationError, - "1 validation error for IntegerMetadataProperty\n__root__\n `IntegerMetadataProperty` with name=int-metadata-property cannot have `max` less or equal than `min`", - ), - ( - {"name": "int-metadata-property", "extra-arg": 5}, - ValidationError, - "1 validation error for IntegerMetadataProperty\nextra-arg\n extra fields not permitted \(type=value_error.extra\)", - ), - ], -) -def test_integer_metadata_property_errors( - schema_kwargs: Dict[str, Any], exception_cls: Any, exception_message: str -) -> None: - with pytest.raises(exception_cls, match=exception_message): - IntegerMetadataProperty(**schema_kwargs) - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload, metadata_filter, metadata_property_to_validate", - [ - ( - {"name": "float-metadata", "title": "alt-title"}, - { - "name": "float-metadata", - "title": "alt-title", - "visible_for_annotators": True, - "settings": {"type": "float"}, - }, - FloatMetadataFilter(name="float-metadata", le=10.0, ge=5.0), - {"float-metadata": 7.5}, - ), - ( - { - "name": "float-metadata", - "visible_for_annotators": False, - "max": 5.0, - }, - { - "name": "float-metadata", - "title": "float-metadata", - "visible_for_annotators": False, - "settings": {"type": "float", "max": 5.0}, - }, - FloatMetadataFilter(name="float-metadata", le=5.0, ge=0.0), - {"float-metadata": 2.5}, - ), - ( - {"name": "float-metadata", "min": 5.0}, - { - "name": "float-metadata", - "title": "float-metadata", - "visible_for_annotators": True, - "settings": {"type": "float", "min": 5.0}, - }, - FloatMetadataFilter(name="float-metadata", le=10.0, ge=5.0), - {"float-metadata": 7.5}, - ), - ( - {"name": "float-metadata", "min": 5.0, "max": 10.0}, - { - "name": "float-metadata", - "title": "float-metadata", - "visible_for_annotators": True, - "settings": {"type": "float", "min": 5.0, "max": 10.0}, - }, - FloatMetadataFilter(name="float-metadata", le=10.0, ge=5.0), - {"float-metadata": 7.5}, - ), - ], -) -def test_float_metadata_property( 
- schema_kwargs: Dict[str, Any], - server_payload: Dict[str, Any], - metadata_filter: FloatMetadataFilter, - metadata_property_to_validate: Dict[str, str], -) -> None: - metadata_property = FloatMetadataProperty(**schema_kwargs) - assert metadata_property.type == MetadataPropertyTypes.float - assert metadata_property.server_settings == server_payload["settings"] - assert metadata_property.to_server_payload() == server_payload - - metadata_property._validate_filter(metadata_filter=metadata_filter) # ValidationError not raised - - metadata_field, metadata_validator = metadata_property._pydantic_field_with_validator - MetadataModel = create_model("MetadataModel", **metadata_field, __validators__=metadata_validator) - assert MetadataModel(**metadata_property_to_validate) - - -@pytest.mark.parametrize( - "schema_kwargs, exception_cls, exception_message", - [ - ({"name": "a b"}, ValidationError, "name\n string does not match regex"), - ( - {"name": "float-metadata-property", "min": 5.0, "max": 5.0}, - ValidationError, - "1 validation error for FloatMetadataProperty\n__root__\n `FloatMetadataProperty` with name=float-metadata-property cannot have `max` less or equal than `min`", - ), - ( - {"name": "float-metadata-property", "min": 6.0, "max": 5.0}, - ValidationError, - "1 validation error for FloatMetadataProperty\n__root__\n `FloatMetadataProperty` with name=float-metadata-property cannot have `max` less or equal than `min`", - ), - ( - {"name": "float-metadata-property", "extra-arg": 5.0}, - ValidationError, - "1 validation error for FloatMetadataProperty\nextra-arg\n extra fields not permitted \(type=value_error.extra\)", - ), - ], -) -def test_float_metadata_property_errors( - schema_kwargs: Dict[str, Any], exception_cls: Any, exception_message: str -) -> None: - with pytest.raises(exception_cls, match=exception_message): - FloatMetadataProperty(**schema_kwargs) - - -@pytest.mark.parametrize( - "schema_kwargs, query_string", - [ - ({"name": "name", "values": ["a"]}, 
"name:a"), - ({"name": "name-with-hyphen", "values": ["a", "b"]}, "name-with-hyphen:a,b"), - ({"name": "name_with_underscore", "values": ["a", "b", "c"]}, "name_with_underscore:a,b,c"), - ], -) -def test_terms_metadata_filter(schema_kwargs: Dict[str, Any], query_string: str) -> None: - metadata_filter = TermsMetadataFilter(**schema_kwargs) - assert metadata_filter.type == MetadataPropertyTypes.terms - assert metadata_filter.query_string == query_string - - -@pytest.mark.parametrize( - "schema_kwargs, exception_cls, exception_message", - [ - ({"name": "a b"}, ValidationError, "name\n string does not match regex"), - ( - {"name": "terms-metadata-filter", "values": []}, - ValidationError, - "1 validation error for TermsMetadataFilter\nvalues\n ensure this value has at least 1 items", - ), - ( - {"name": "terms-metadata-filter", "values": ["a", "a"]}, - ValidationError, - "1 validation error for TermsMetadataFilter\nvalues\n `TermsMetadataFilter` with name=terms-metadata-filter cannot have repeated `values`", - ), - ], -) -def test_terms_metadata_filter_errors( - schema_kwargs: Dict[str, Any], exception_cls: Any, exception_message: str -) -> None: - with pytest.raises(exception_cls, match=exception_message): - TermsMetadataFilter(**schema_kwargs) - - -@pytest.mark.parametrize( - "schema_kwargs, query_string", - [ - ({"name": "name", "le": 5}, 'name:{"le": 5}'), - ({"name": "name", "ge": 5}, 'name:{"ge": 5}'), - ({"name": "name", "le": 5, "ge": 1}, 'name:{"le": 5, "ge": 1}'), - ({"name": "name", "le": 5, "ge": 5}, 'name:{"le": 5, "ge": 5}'), - ], -) -def test_integer_metadata_filter(schema_kwargs: Dict[str, Any], query_string: str) -> None: - metadata_filter = IntegerMetadataFilter(**schema_kwargs) - assert metadata_filter.type == MetadataPropertyTypes.integer - assert metadata_filter.query_string == query_string - - -@pytest.mark.parametrize( - "schema_kwargs, exception_cls, exception_message", - [ - ({"name": "a b"}, ValidationError, "name\n string does not match 
regex"), - ( - {"name": "int-metadata-filter"}, - ValueError, - "1 validation error for IntegerMetadataFilter\n__root__\n `IntegerMetadataFilter` with name=int-metadata-filter must have at least one of `le` or `ge`", - ), - ( - {"name": "int-metadata-filter", "le": 5, "ge": 6}, - ValidationError, - "1 validation error for IntegerMetadataFilter\n__root__\n `IntegerMetadataFilter` with name=int-metadata-filter cannot have `le` less than `ge`", - ), - ], -) -def test_integer_metadata_filter_errors( - schema_kwargs: Dict[str, Any], exception_cls: Any, exception_message: str -) -> None: - with pytest.raises(exception_cls, match=exception_message): - IntegerMetadataFilter(**schema_kwargs) - - -@pytest.mark.parametrize( - "schema_kwargs, query_string", - [ - ({"name": "name", "le": 5.0}, 'name:{"le": 5.0}'), - ({"name": "name", "ge": 5.0}, 'name:{"ge": 5.0}'), - ({"name": "name", "ge": 1.0, "le": 5.0}, 'name:{"le": 5.0, "ge": 1.0}'), - ({"name": "name", "le": 5.0, "ge": 5.0}, 'name:{"le": 5.0, "ge": 5.0}'), - ], -) -def test_float_metadata_filter(schema_kwargs: Dict[str, Any], query_string: str) -> None: - metadata_filter = FloatMetadataFilter(**schema_kwargs) - assert metadata_filter.type == MetadataPropertyTypes.float - assert metadata_filter.query_string == query_string - - -@pytest.mark.parametrize( - "schema_kwargs, exception_cls, exception_message", - [ - ({"name": "a b"}, ValidationError, "name\n string does not match regex"), - ( - {"name": "float-metadata-filter"}, - ValueError, - "1 validation error for FloatMetadataFilter\n__root__\n `FloatMetadataFilter` with name=float-metadata-filter must have at least one of `le` or `ge`", - ), - ( - {"name": "float-metadata-filter", "le": 5.0, "ge": 6.0}, - ValidationError, - "1 validation error for FloatMetadataFilter\n__root__\n `FloatMetadataFilter` with name=float-metadata-filter cannot have `le` less than `ge`", - ), - ], -) -def test_float_metadata_filter_errors( - schema_kwargs: Dict[str, Any], exception_cls: Any, 
exception_message: str -) -> None: - with pytest.raises(exception_cls, match=exception_message): - FloatMetadataFilter(**schema_kwargs) diff --git a/argilla-v1/tests/unit/feedback/schemas/test_questions.py b/argilla-v1/tests/unit/feedback/schemas/test_questions.py deleted file mode 100644 index ddab700c4..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/test_questions.py +++ /dev/null @@ -1,638 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Any, Dict - -import pytest -from argilla_v1.client.feedback.schemas.enums import LabelsOrder, QuestionTypes -from argilla_v1.client.feedback.schemas.questions import ( - LabelQuestion, - MultiLabelQuestion, - RankingQuestion, - RatingQuestion, - SpanLabelOption, - SpanQuestion, - TextQuestion, - _LabelQuestion, -) - -from tests.pydantic_v1 import ValidationError - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a", "required": True, "use_markdown": True}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": {"type": "text", "use_markdown": True, "use_table": False}, - }, - ), - ( - {"name": "a", "title": "B", "description": "b", "required": False, "use_markdown": False}, - { - "name": "a", - "title": "B", - "description": "b", - "required": False, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - }, - ), - ( - {"name": "a", "required": True, "use_table": True}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": {"type": "text", "use_markdown": False, "use_table": True}, - }, - ), - ], -) -def test_text_question(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - text_question = TextQuestion(**schema_kwargs) - assert text_question.type == QuestionTypes.text - assert text_question.server_settings == server_payload["settings"] - assert text_question.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a", "values": [8, 9, 10]}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": {"type": "rating", "options": [{"value": 8}, {"value": 9}, {"value": 10}]}, - }, - ), - ( - {"name": "a", "title": "A", "description": "a", "required": False, "values": [0, 1, 2, 3]}, - { - "name": "a", - "title": "A", - "description": "a", - "required": False, - "settings": {"type": "rating", 
"options": [{"value": 0}, {"value": 1}, {"value": 2}, {"value": 3}]}, - }, - ), - ], -) -def test_rating_question(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - rating_question = RatingQuestion(**schema_kwargs) - assert rating_question.type == QuestionTypes.rating - assert rating_question.server_settings == server_payload["settings"] - assert rating_question.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "schema_kwargs, exception_cls, exception_message", - [ - ({"name": "a", "values": ["a", "b"]}, ValidationError, "value is not a valid integer"), - ({"name": "a", "values": [1, 1, 1]}, ValidationError, "the list has duplicated items"), - ({"name": "a", "values": [1]}, ValidationError, "ensure this value has at least 2 items"), - ({"name": "a", "values": [-1, 0, 1]}, ValidationError, "ensure this value is greater than or equal to 0"), - ({"name": "a", "values": [1, 11]}, ValidationError, "ensure this value is less than or equal to 10"), - ], -) -def test_rating_question_errors(schema_kwargs: Dict[str, Any], exception_cls: Any, exception_message: str) -> None: - with pytest.raises(exception_cls, match=exception_message): - RatingQuestion(**schema_kwargs) - - -@pytest.mark.parametrize( - "schema_kwargs, exception_cls, exception_message", - [ - ( - {"name": "a", "labels": ["a", "b"], "visible_labels": 2}, - ValidationError, - "ensure this value is greater than or equal to 3", - ), - ({"name": "a", "labels": ["a", "a"]}, ValidationError, "the list has duplicated items"), - ({"name": "a", "labels": "a"}, ValidationError, r"(value is not a valid list)|(value is not a valid dict)"), - ({"name": "a", "labels": {"a": "a"}}, ValidationError, "ensure this dict has at least 2 items"), - ({"name": "a", "labels": {"a": "a", "b": "a"}}, ValidationError, "ensure this dict has unique values"), - ], -) -def test_label_question_errors(schema_kwargs: Dict[str, Any], exception_cls: Any, exception_message: str) -> None: - with 
pytest.raises(exception_cls, match=exception_message): - _LabelQuestion(**schema_kwargs, type="label_selection") - - -@pytest.mark.parametrize( - "schema_kwargs, warning_cls, warning_message", - [ - ( - {"name": "a", "labels": ["a", "b", "c"], "visible_labels": 4}, - UserWarning, - "\`visible_labels=4\` is greater than the total number of labels \(3\), so it will be set to \`3\`.", - ), - ( - {"name": "a", "labels": ["a", "b"], "visible_labels": 3}, - UserWarning, - "\`labels=\['a', 'b'\]\` has less than 3 labels, so \`visible_labels\` will be set to \`None\`, which means that all the labels will be visible.", - ), - ( - {"name": "a", "labels": list(range(100))}, - UserWarning, - "Since \`visible_labels\` has not been provided and the total number of labels is greater than 20, \`visible_labels\` will be set to \`20\`.", - ), - ], -) -def test_label_question_warnings(schema_kwargs: Dict[str, Any], warning_cls: Warning, warning_message: str) -> None: - with pytest.warns(warning_cls, match=warning_message): - _LabelQuestion(**schema_kwargs, type="label_selection") - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a", "labels": ["a", "b"]}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "label_selection", - "options": [{"value": "a", "text": "a"}, {"value": "b", "text": "b"}], - "visible_options": None, - }, - }, - ), - ( - {"name": "a", "labels": {"a": "A", "b": "B"}}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "label_selection", - "options": [{"value": "a", "text": "A"}, {"value": "b", "text": "B"}], - "visible_options": None, - }, - }, - ), - ( - {"name": "a", "labels": ["a", "b"], "visible_labels": 3}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "label_selection", - "options": [{"value": "a", "text": "a"}, {"value": "b", "text": "b"}], - 
"visible_options": None, - }, - }, - ), - ( - {"name": "a", "labels": ["a", "b"], "visible_labels": None}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "label_selection", - "options": [{"value": "a", "text": "a"}, {"value": "b", "text": "b"}], - "visible_options": None, - }, - }, - ), - ( - {"name": "a", "labels": list(range(20))}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "label_selection", - "options": [{"value": str(n), "text": str(n)} for n in list(range(20))], - "visible_options": None, - }, - }, - ), - ( - {"name": "a", "labels": list(range(21))}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "label_selection", - "options": [{"value": str(n), "text": str(n)} for n in list(range(21))], - "visible_options": 20, - }, - }, - ), - ( - {"name": "a", "labels": list(range(2)), "visible_labels": None}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "label_selection", - "options": [{"value": str(n), "text": str(n)} for n in list(range(2))], - "visible_options": None, - }, - }, - ), - ( - {"name": "a", "labels": list(range(2)), "visible_labels": 3}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "label_selection", - "options": [{"value": str(n), "text": str(n)} for n in list(range(2))], - "visible_options": None, - }, - }, - ), - ], -) -def test_label_question(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - label_question = LabelQuestion(**schema_kwargs) - assert label_question.type == QuestionTypes.label_selection - assert label_question.server_settings == server_payload["settings"] - assert label_question.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a", "labels": ["a", "b"]}, - 
{ - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "multi_label_selection", - "options": [{"value": "a", "text": "a"}, {"value": "b", "text": "b"}], - "visible_options": None, - "options_order": LabelsOrder.natural, - }, - }, - ), - ( - {"name": "a", "labels": {"a": "A", "b": "B"}, "labels_order": LabelsOrder.suggestion}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "multi_label_selection", - "options": [{"value": "a", "text": "A"}, {"value": "b", "text": "B"}], - "visible_options": None, - "options_order": LabelsOrder.suggestion, - }, - }, - ), - ( - {"name": "a", "labels": ["a", "b"], "visible_labels": 3}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "multi_label_selection", - "options": [{"value": "a", "text": "a"}, {"value": "b", "text": "b"}], - "visible_options": None, - "options_order": LabelsOrder.natural, - }, - }, - ), - ( - {"name": "a", "labels": ["a", "b"], "visible_labels": None}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "multi_label_selection", - "options": [{"value": "a", "text": "a"}, {"value": "b", "text": "b"}], - "visible_options": None, - "options_order": LabelsOrder.natural, - }, - }, - ), - ( - {"name": "a", "labels": list(range(20))}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "multi_label_selection", - "options": [{"value": str(n), "text": str(n)} for n in list(range(20))], - "visible_options": None, - "options_order": LabelsOrder.natural, - }, - }, - ), - ( - {"name": "a", "labels": list(range(21))}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "multi_label_selection", - "options": [{"value": str(n), "text": str(n)} for n in list(range(21))], - "visible_options": 20, - "options_order": 
LabelsOrder.natural, - }, - }, - ), - ( - {"name": "a", "labels": list(range(2)), "visible_labels": None}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "multi_label_selection", - "options": [{"value": str(n), "text": str(n)} for n in list(range(2))], - "visible_options": None, - "options_order": LabelsOrder.natural, - }, - }, - ), - ( - {"name": "a", "labels": list(range(2)), "visible_labels": 3}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": { - "type": "multi_label_selection", - "options": [{"value": str(n), "text": str(n)} for n in list(range(2))], - "visible_options": None, - "options_order": LabelsOrder.natural, - }, - }, - ), - ], -) -def test_multi_label_question(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - label_question = MultiLabelQuestion(**schema_kwargs) - assert label_question.type == QuestionTypes.multi_label_selection - assert label_question.server_settings == server_payload["settings"] - assert label_question.to_server_payload() == server_payload - - -@pytest.mark.parametrize( - "schema_kwargs, server_payload", - [ - ( - {"name": "a", "values": ["a", "b"]}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": {"type": "ranking", "options": [{"value": "a", "text": "a"}, {"value": "b", "text": "b"}]}, - }, - ), - ( - {"name": "a", "values": {"a": "A", "b": "B"}}, - { - "name": "a", - "title": "A", - "description": None, - "required": True, - "settings": {"type": "ranking", "options": [{"value": "a", "text": "A"}, {"value": "b", "text": "B"}]}, - }, - ), - ], -) -def test_ranking_question(schema_kwargs: Dict[str, Any], server_payload: Dict[str, Any]) -> None: - ranking_question = RankingQuestion(**schema_kwargs) - assert ranking_question.type == QuestionTypes.ranking - assert ranking_question.server_settings == server_payload["settings"] - assert ranking_question.to_server_payload() 
== server_payload - - -@pytest.mark.parametrize( - "schema_kwargs, exception_cls, exception_message", - [ - ({"name": "a", "values": [1, 1]}, ValidationError, "the list has duplicated items"), - ({"name": "a", "values": ["a"]}, ValidationError, "ensure this value has at least 2 items"), - ({"name": "a", "values": {"a": "a"}}, ValidationError, "ensure this dict has at least 2 items"), - ({"name": "a", "values": {1: "a", 2: "a"}}, ValidationError, "ensure this dict has unique values"), - ], -) -def test_ranking_question_errors(schema_kwargs: Dict[str, Any], exception_cls: Any, exception_message: str) -> None: - with pytest.raises(exception_cls, match=exception_message): - RankingQuestion(**schema_kwargs) - - -def test_span_question() -> None: - question = SpanQuestion( - name="question", - field="field", - title="Question", - description="Description", - required=True, - allow_overlapping=True, - labels=["a", "b"], - ) - - assert question.type == QuestionTypes.span - assert question.server_settings == { - "type": "span", - "field": "field", - "visible_options": None, - "allow_overlapping": True, - "options": [{"value": "a", "text": "a", "description": None}, {"value": "b", "text": "b", "description": None}], - } - - -def test_span_question_with_labels_dict() -> None: - question = SpanQuestion( - name="question", - field="field", - title="Question", - description="Description", - labels={"a": "A text", "b": "B text"}, - ) - - assert question.type == QuestionTypes.span - assert question.server_settings == { - "type": "span", - "field": "field", - "visible_options": None, - "allow_overlapping": False, - "options": [ - {"value": "a", "text": "A text", "description": None}, - {"value": "b", "text": "B text", "description": None}, - ], - } - - -def test_span_question_with_visible_labels() -> None: - question = SpanQuestion( - name="question", - field="field", - title="Question", - description="Description", - labels=["a", "b", "c", "d"], - visible_labels=3, - ) - - assert 
question.type == QuestionTypes.span - assert question.server_settings == { - "type": "span", - "field": "field", - "visible_options": 3, - "allow_overlapping": False, - "options": [ - {"value": "a", "text": "a", "description": None}, - {"value": "b", "text": "b", "description": None}, - {"value": "c", "text": "c", "description": None}, - {"value": "d", "text": "d", "description": None}, - ], - } - - -def test_span_question_with_visible_labels_default_value(): - question = SpanQuestion( - name="question", - field="field", - title="Question", - description="Description", - labels=list(range(21)), - ) - - assert question.visible_labels == 20 - - -def test_span_question_with_default_visible_label_when_labels_is_less_than_20(): - with pytest.warns(UserWarning, match=""): - question = SpanQuestion( - name="question", - field="field", - title="Question", - description="Description", - labels=list(range(19)), - ) - - assert question.visible_labels == 19 - - -def test_span_question_when_visible_labels_is_greater_than_total_labels(): - with pytest.warns( - UserWarning, - match="`visible_labels=4` is greater than the total number of labels \(3\)", - ): - question = SpanQuestion( - name="question", - field="field", - title="Question", - description="Description", - labels=["a", "b", "c"], - visible_labels=4, - ) - - assert question.visible_labels == 3 - - -def test_span_question_with_visible_labels_less_than_total_labels(): - with pytest.warns( - UserWarning, match="Since `labels` has less than 3 labels, `visible_labels` will be set to `None`." 
- ): - question = SpanQuestion( - name="question", - field="field", - title="Question", - description="Description", - labels=["a", "b"], - visible_labels=3, - ) - - assert question.visible_labels is None - - -def test_span_question_with_visible_labels_less_than_min_value(): - with pytest.raises(ValidationError, match="ensure this value is greater than or equal to 3"): - SpanQuestion( - name="question", - field="field", - title="Question", - description="Description", - labels=["a", "b"], - visible_labels=2, - ) - - -def test_span_questions_with_default_visible_labels_and_less_labels_than_default(): - with pytest.warns(UserWarning, match="visible_labels=20` is greater than the total number of labels"): - question = SpanQuestion( - name="question", - field="field", - title="Question", - description="Description", - labels=list(range(10)), - ) - - assert question.visible_labels == 10 - - -def test_span_question_with_no_labels() -> None: - with pytest.raises(ValidationError, match="At least one label must be provided"): - SpanQuestion( - name="question", - field="field", - title="Question", - description="Description", - labels=[], - ) - - -def test_span_question_with_duplicated_labels() -> None: - with pytest.raises(ValidationError, match="the list has duplicated items"): - SpanQuestion( - name="question", - title="Question", - field="field", - description="Description", - labels=[SpanLabelOption(value="a", text="A text"), SpanLabelOption(value="a", text="Text for A")], - ) diff --git a/argilla-v1/tests/unit/feedback/schemas/test_records.py b/argilla-v1/tests/unit/feedback/schemas/test_records.py deleted file mode 100644 index 00d87e2be..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/test_records.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import warnings -from typing import Any, Dict, List, Optional, Type, Union - -import pytest -from argilla_v1.client.feedback.schemas.records import ( - FeedbackRecord, - RankingValueSchema, - ResponseSchema, - SortBy, - SuggestionSchema, - ValueSchema, -) - -from tests.pydantic_v1 import ValidationError - - -@pytest.mark.parametrize( - "schema_kwargs", - [ - { - "fields": {"text": "This is the first record", "label": "positive"}, - "metadata": {"first": True, "nested": {"more": "stuff"}}, - "responses": [ - { - "values": {"question-1": {"value": "This is the first answer"}, "question-2": {"value": 5}}, - "status": "submitted", - }, - { - "values": {"question-1": {"value": "This is the first answer"}, "question-2": {"value": 5}}, - "status": "draft", - }, - { - "values": {}, - "status": "discarded", - }, - ], - "suggestions": [ - { - "question_name": "question-1", - "type": "model", - "score": 0.9, - "value": "This is the first suggestion", - "agent": "agent-1", - } - ], - "vectors": { - "vector-1": [1.0, 2.0, 3.0], - "vector-2": [1.0, 2.0, 3.0, 4.0], - }, - "external_id": "entry-1", - }, - { - "fields": {"text": "This is the first record", "label": "positive"}, - }, - { - "fields": {"text": "This is the first record", "label": "positive"}, - "metadata": {"first": True, "nested": {"more": "stuff"}}, - }, - { - "fields": {"text": "This is the first record", "label": "positive"}, - "responses": [ - { - "values": {"question-1": {"value": "This is the first answer"}, "question-2": {"value": 5}}, - "status": "submitted", - } - ], - }, - { - "fields": {"text": 
"This is the first record", "label": "positive"}, - "suggestions": [ - { - "question_name": "question-1", - "type": "model", - "score": 0.9, - "value": "This is the first suggestion", - "agent": "agent-1", - } - ], - }, - ], -) -def test_feedback_record(schema_kwargs: Dict[str, Any]) -> None: - assert FeedbackRecord(**schema_kwargs) - - -# TODO(@alvaro): Check why there are missing tests cases checking feedback errors. - - -@pytest.mark.parametrize( - "schema_kwargs, suggestions, expected_warning, warning_match", - [ - ( - { - "fields": {"text": "This is the first record", "label": "positive"}, - }, - { - "question_name": "question-1", - "type": "model", - "score": 0.9, - "value": "This is the first suggestion", - "agent": "agent-1", - }, - None, - None, - ), - ( - { - "fields": {"text": "This is the first record", "label": "positive"}, - "suggestions": [{"question_name": "question-1", "value": "This is the first suggestion"}], - }, - { - "question_name": "question-1", - "type": "model", - "score": 0.9, - "value": "This is the second suggestion", - "agent": "agent-2", - }, - UserWarning, - "A suggestion for question `question-1` has already been provided", - ), - ], -) -def test_feedback_record_update( - schema_kwargs: Dict[str, Any], - suggestions: Union[SuggestionSchema, List[SuggestionSchema], Dict[str, Any], List[Dict[str, Any]]], - expected_warning: Optional[Type[Warning]], - warning_match: Optional[str], -) -> None: - record = FeedbackRecord(**schema_kwargs) - - if expected_warning is None: - with warnings.catch_warnings(record=True) as record_: - record.update(suggestions) - assert len(record_) == 0 - else: - with pytest.warns(expected_warning, match=warning_match): - record.update(suggestions) - - -@pytest.mark.parametrize( - "schema_kwargs", - [ - {"value": 1}, - {"value": "This is a value"}, - {"value": ["This is a value"]}, - {"value": [{"rank": 1, "value": "This is a value"}]}, - ], -) -def test_value_schema(schema_kwargs: Dict[str, Any]) -> None: - 
assert ValueSchema(**schema_kwargs) - - -@pytest.mark.parametrize( - "schema_kwargs", - [ - {"value": "question-1", "rank": 1}, - {"value": "question-1", "rank": None}, - ], -) -def test_ranking_value_schema(schema_kwargs: Dict[str, Any]) -> None: - assert RankingValueSchema(**schema_kwargs) - - -@pytest.mark.parametrize( - "schema_kwargs, expected_exception, expected_exception_message", - [ - ( - {"value": 1, "rank": 1}, - ValidationError, - "str type expected", - ), - ( - {"value": "question-1", "rank": "string"}, - ValidationError, - "value is not a valid integer", - ), - ( - {"value": "question-1", "rank": 0}, - ValidationError, - "ensure this value is greater than or equal to 1", - ), - ], -) -def test_ranking_value_schema_errors( - schema_kwargs: Dict[str, Any], expected_exception: Exception, expected_exception_message: str -) -> None: - with pytest.raises(expected_exception, match=expected_exception_message): - RankingValueSchema(**schema_kwargs) - - -@pytest.mark.parametrize( - "schema_kwargs", - [ - {"user_id": "00000000-0000-0000-0000-000000000000", "values": {"question-1": {"value": 1}}}, - {"user_id": "00000000-0000-0000-0000-000000000000", "values": {"question-1": {"value": "This is a value"}}}, - {"user_id": "00000000-0000-0000-0000-000000000000", "values": {"question-1": {"value": ["This is a value"]}}}, - { - "user_id": "00000000-0000-0000-0000-000000000000", - "values": {"question-1": {"value": [{"rank": 1, "value": "This is a value"}]}}, - }, - {"values": {"question-1": {"value": 1}}}, - {"values": {"question-1": {"value": "This is a value"}}, "status": "submitted"}, - {"values": None, "status": "discarded"}, - ], -) -def test_response_schema(schema_kwargs: Dict[str, Any]) -> None: - assert ResponseSchema(**schema_kwargs) - - -@pytest.mark.parametrize( - "schema_kwargs", - [ - {"question_name": "question-1", "value": 1}, - {"question_name": "question-1", "value": "This is a value"}, - {"question_name": "question-1", "value": ["This is a value"]}, 
- {"question_name": "question-1", "value": [{"rank": 1, "value": "This is a value"}]}, - { - "question_name": "question-1", - "value": [{"rank": 1, "value": "This is a value"}, {"rank": 2, "value": "This is a value"}], - "score": 0.9, - "type": "model", - "agent": "agent-1", - }, - ], -) -def test_suggestion_schema(schema_kwargs: Dict[str, Any]) -> None: - assert SuggestionSchema(**schema_kwargs) - - -@pytest.mark.parametrize( - "wrong_args", - [ - dict(field="wrogn_name"), - dict(field="metadata.field", order="wrong_order"), - dict(dict="ascc", order="asc"), - ], -) -def test_sort_by_with_wrong_fields(wrong_args: Dict[str, Any]) -> None: - with pytest.raises(ValidationError): - SortBy(**wrong_args) diff --git a/argilla-v1/tests/unit/feedback/schemas/test_responses.py b/argilla-v1/tests/unit/feedback/schemas/test_responses.py deleted file mode 100644 index 41fc2f81f..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/test_responses.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest -from argilla_v1.feedback import ( - ResponseSchema, - ResponseStatus, - SpanValueSchema, - TextQuestion, -) - - -def test_create_span_response_wrong_limits(): - with pytest.raises(ValueError, match="The end of the span must be greater than the start."): - SpanValueSchema(start=10, end=8, label="test") - - -def test_create_response(): - question = TextQuestion(name="text") - response = ResponseSchema(status="draft", values=[question.response("Value for text")]) - - assert response.status == ResponseStatus.draft - assert question.name in response.values - assert response.values[question.name].value == "Value for text" - - -def test_create_responses_with_multiple_questions(): - question1 = TextQuestion(name="text") - question2 = TextQuestion(name="text2") - response = ResponseSchema( - status="draft", - values=[ - question1.response("Value for text"), - question2.response("Value for text2"), - ], - ) - - assert response.status == ResponseStatus.draft - assert question1.name in response.values - assert response.values[question1.name].value == "Value for text" - assert question2.name in response.values - assert response.values[question2.name].value == "Value for text2" - - -def test_create_response_with_wrong_value(): - with pytest.raises(ValueError, match="Value 10 is not valid for question type text. 
Expected ."): - ResponseSchema(status="draft", values=[TextQuestion(name="text").response(10)]) - - -def test_response_to_server_payload_with_string_status(): - assert ResponseSchema(status="draft").to_server_payload() == {"user_id": None, "status": "draft"} - - -def test_response_to_server_payload_with_no_values(): - assert ResponseSchema().to_server_payload() == {"user_id": None, "status": "submitted"} - assert ResponseSchema(values=None).to_server_payload() == {"user_id": None, "status": "submitted", "values": None} - assert ResponseSchema(values=[]).to_server_payload() == {"user_id": None, "status": "submitted", "values": {}} - assert ResponseSchema(values={}).to_server_payload() == {"user_id": None, "status": "submitted", "values": {}} diff --git a/argilla-v1/tests/unit/feedback/schemas/test_suggestions.py b/argilla-v1/tests/unit/feedback/schemas/test_suggestions.py deleted file mode 100644 index 935e9463f..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/test_suggestions.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest -from argilla_v1.feedback import TextQuestion - - -def test_create_suggestion(): - question = TextQuestion(name="text") - - suggestion = question.suggestion("Value for text", agent="mock") - - assert suggestion.question_name == question.name - assert suggestion.agent == "mock" - - -def test_create_suggestion_with_wrong_value(): - with pytest.raises(ValueError, match="Value 10 is not valid for question type text. Expected ."): - TextQuestion(name="text").suggestion(value=10, agent="Mock") diff --git a/argilla-v1/tests/unit/feedback/schemas/test_utils.py b/argilla-v1/tests/unit/feedback/schemas/test_utils.py deleted file mode 100644 index 583cfe92a..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/test_utils.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from argilla_v1.client.feedback.schemas.utils import LabelMappingMixin - -from tests.pydantic_v1 import BaseModel, Field - - -def test_label_mapping_mixin() -> None: - class TestLabelMappingMixin(BaseModel, LabelMappingMixin): - server_settings: dict = Field(default_factory=dict) - - my_class = TestLabelMappingMixin(server_settings={"options": [{"value": "label1"}, {"value": "label2"}]}) - assert my_class.__all_labels__ == ["label1", "label2"] - assert my_class.__label2id__ == {"label1": 0, "label2": 1} - assert my_class.__id2label__ == {0: "label1", 1: "label2"} diff --git a/argilla-v1/tests/unit/feedback/schemas/test_vectors_settings.py b/argilla-v1/tests/unit/feedback/schemas/test_vectors_settings.py deleted file mode 100644 index 55be41799..000000000 --- a/argilla-v1/tests/unit/feedback/schemas/test_vectors_settings.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/argilla-v1/tests/unit/feedback/test_config.py b/argilla-v1/tests/unit/feedback/test_config.py deleted file mode 100644 index c658ab955..000000000 --- a/argilla-v1/tests/unit/feedback/test_config.py +++ /dev/null @@ -1,853 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import re -from typing import TYPE_CHECKING, Any, Dict, List - -import pytest -from argilla_v1.client.feedback.config import DatasetConfig, DeprecatedDatasetConfig -from argilla_v1.client.feedback.schemas.enums import LabelsOrder -from argilla_v1.client.feedback.schemas.fields import FieldSchema -from argilla_v1.client.feedback.schemas.questions import QuestionSchema -from yaml import SafeLoader, load - -if TYPE_CHECKING: - from argilla_v1.client.feedback.schemas.types import ( - AllowedFieldTypes, - AllowedMetadataPropertyTypes, - AllowedQuestionTypes, - ) - - -def test_dataset_config_yaml( - feedback_dataset_fields: List["AllowedFieldTypes"], - feedback_dataset_questions: List["AllowedQuestionTypes"], - feedback_dataset_metadata_properties: List["AllowedMetadataPropertyTypes"], - feedback_dataset_guidelines: str, -) -> None: - config = DatasetConfig( - fields=feedback_dataset_fields, - questions=feedback_dataset_questions, - metadata_properties=feedback_dataset_metadata_properties, - guidelines=feedback_dataset_guidelines, - ) - assert isinstance(config, DatasetConfig) - assert config.fields == feedback_dataset_fields - assert config.questions == feedback_dataset_questions - assert config.metadata_properties == feedback_dataset_metadata_properties - assert config.guidelines == feedback_dataset_guidelines - - to_yaml_config = config.to_yaml() - assert isinstance(to_yaml_config, str) - assert all(f"name: {field.name}" in to_yaml_config for field in feedback_dataset_fields) - assert all(f"name: {question.name}" in to_yaml_config for question 
in feedback_dataset_questions) - assert all( - f"name: {metadata_property.name}" in to_yaml_config - for metadata_property in feedback_dataset_metadata_properties - ) - assert f"guidelines: {feedback_dataset_guidelines}" in to_yaml_config - - assert "!!python/object:uuid.UUID" not in to_yaml_config - - from_yaml_config = DatasetConfig.from_yaml(to_yaml_config) - assert isinstance(from_yaml_config, DatasetConfig) - assert from_yaml_config.fields == feedback_dataset_fields - assert from_yaml_config.questions == feedback_dataset_questions - assert from_yaml_config.metadata_properties == feedback_dataset_metadata_properties - assert from_yaml_config.guidelines == feedback_dataset_guidelines - - -@pytest.mark.usefixtures("feedback_dataset_fields", "feedback_dataset_questions", "feedback_dataset_guidelines") -def test_dataset_config_json_deprecated( - feedback_dataset_fields: List["AllowedFieldTypes"], - feedback_dataset_questions: List["AllowedQuestionTypes"], - feedback_dataset_guidelines: str, -) -> None: - config = DeprecatedDatasetConfig( - fields=feedback_dataset_fields, - questions=feedback_dataset_questions, - guidelines=feedback_dataset_guidelines, - ) - assert isinstance(config, DeprecatedDatasetConfig) - assert config.fields == feedback_dataset_fields - assert config.questions == feedback_dataset_questions - assert config.guidelines == feedback_dataset_guidelines - - with pytest.warns(DeprecationWarning, match="`DatasetConfig` can just be dumped to YAML"): - to_json_config = config.to_json() - assert isinstance(to_json_config, str) - assert all(f'"name": "{field.name}"' in to_json_config for field in feedback_dataset_fields) - assert all(f'"name": "{question.name}"' in to_json_config for question in feedback_dataset_questions) - assert f'"guidelines": "{feedback_dataset_guidelines}"' in to_json_config - - with pytest.warns(DeprecationWarning, match="`DatasetConfig` can just be loaded from YAML"): - from_json_config = config.from_json(to_json_config) - assert 
isinstance(from_json_config, DeprecatedDatasetConfig) - assert from_json_config.fields == feedback_dataset_fields - assert from_json_config.questions == feedback_dataset_questions - assert from_json_config.guidelines == feedback_dataset_guidelines - - -@pytest.mark.parametrize( - "argilla_version, outdated_config", - ( - ( - "1.8.0", - { - "fields": [ - { - "name": "field-1", - "title": "Field-1", - "required": True, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "id": "14585f01-2c97-4756-92cc-cc7af17bc342", - "inserted_at": "2023-10-10T16:20:23", - "updated_at": "2023-10-10T16:20:23", - }, - { - "name": "field-2", - "title": "Field-2", - "required": False, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "id": "aaaddedc-a273-478a-a27a-b0970b61a7ef", - "inserted_at": "2023-10-10T16:20:23", - "updated_at": "2023-10-10T16:20:23", - }, - ], - "questions": [ - { - "name": "question-1", - "title": "Question-1", - "description": None, - "required": True, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "id": "421982f2-b1e6-4725-91d6-8e8b908a9b6b", - "inserted_at": "2023-10-10T16:20:23", - "updated_at": "2023-10-10T16:20:23", - }, - { - "name": "question-2", - "title": "Question-2", - "description": None, - "required": False, - "settings": { - "type": "rating", - "options": [{"value": 1}, {"value": 2}, {"value": 3}, {"value": 4}, {"value": 5}], - }, - "id": "95f38f77-f4dc-4b1a-9638-1d419359be36", - "inserted_at": "2023-10-10T16:20:24", - "updated_at": "2023-10-10T16:20:24", - }, - ], - "guidelines": "These are the guidelines", - }, - ), - ( - "1.9.0", - { - "fields": [ - { - "name": "field-1", - "title": "Field-1", - "required": True, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - "id": "71b3b494-a8fa-4fb7-98e5-d1d73a3f5f81", - "inserted_at": "2023-10-10T16:18:15", - "updated_at": "2023-10-10T16:18:15", - }, - { - "name": "field-2", 
- "title": "Field-2", - "required": False, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - "id": "29ae19df-5fc1-4f65-892f-5bc03df3066b", - "inserted_at": "2023-10-10T16:18:15", - "updated_at": "2023-10-10T16:18:15", - }, - ], - "questions": [ - { - "name": "question-1", - "title": "Question-1", - "description": None, - "required": True, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - "id": "98672849-651b-4c00-ab2a-7f29087c5b22", - "inserted_at": "2023-10-10T16:18:15", - "updated_at": "2023-10-10T16:18:15", - }, - { - "name": "question-2", - "title": "Question-2", - "description": None, - "required": False, - "settings": { - "type": "rating", - "options": [{"value": 1}, {"value": 2}, {"value": 3}, {"value": 4}, {"value": 5}], - }, - "id": "17ec3e6a-656a-47fa-b016-2f825a7db5a3", - "inserted_at": "2023-10-10T16:18:16", - "updated_at": "2023-10-10T16:18:16", - }, - { - "name": "question-3", - "title": "Question-3", - "description": None, - "required": False, - "settings": { - "type": "label_selection", - "options": [ - {"value": "label-1", "text": "label-1", "description": None}, - {"value": "label-2", "text": "label-2", "description": None}, - {"value": "label-3", "text": "label-3", "description": None}, - ], - "visible_options": 3, - }, - "visible_labels": 20, - "id": "857faadf-b65f-4a7f-9336-574c0581aef8", - "inserted_at": "2023-10-10T16:18:16", - "updated_at": "2023-10-10T16:18:16", - }, - { - "name": "question-4", - "title": "Question-4", - "description": None, - "required": False, - "settings": { - "type": "multi_label_selection", - "options": [ - {"value": "label-1", "text": "label-1", "description": None}, - {"value": "label-2", "text": "label-2", "description": None}, - {"value": "label-3", "text": "label-3", "description": None}, - ], - "visible_options": 3, - "options_order": LabelsOrder.natural, - }, - "visible_labels": 20, - "id": 
"a3a12c67-73d8-41b6-a697-f88be8f9386c", - "inserted_at": "2023-10-10T16:18:16", - "updated_at": "2023-10-10T16:18:16", - }, - ], - "guidelines": "These are the guidelines", - }, - ), - ( - "1.10.0", - { - "fields": [ - { - "name": "field-1", - "title": "Field-1", - "required": True, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - "id": "f2b70656-4d00-48e5-8309-a45bfd2bfb5a", - "inserted_at": "2023-10-10T16:17:05", - "updated_at": "2023-10-10T16:17:05", - }, - { - "name": "field-2", - "title": "Field-2", - "required": False, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - "id": "8082835f-446f-4ae5-9e0a-426232eb50b1", - "inserted_at": "2023-10-10T16:17:06", - "updated_at": "2023-10-10T16:17:06", - }, - ], - "questions": [ - { - "name": "question-1", - "title": "Question-1", - "description": None, - "required": True, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - "id": "d30e8ec1-9c96-4f8a-9cfe-2082738602ad", - "inserted_at": "2023-10-10T16:17:06", - "updated_at": "2023-10-10T16:17:06", - }, - { - "name": "question-2", - "title": "Question-2", - "description": None, - "required": False, - "settings": { - "type": "rating", - "options": [{"value": 1}, {"value": 2}, {"value": 3}, {"value": 4}, {"value": 5}], - }, - "id": "7502637f-ea92-41b7-ae6e-b338494b55dc", - "inserted_at": "2023-10-10T16:17:06", - "updated_at": "2023-10-10T16:17:06", - }, - { - "name": "question-3", - "title": "Question-3", - "description": None, - "required": False, - "settings": { - "type": "label_selection", - "options": [ - {"value": "label-1", "text": "label-1", "description": None}, - {"value": "label-2", "text": "label-2", "description": None}, - {"value": "label-3", "text": "label-3", "description": None}, - ], - "visible_options": 3, - }, - "visible_labels": 20, - "id": "43f7242e-a52e-4cd5-9823-82b04d1c38e6", - "inserted_at": 
"2023-10-10T16:17:06", - "updated_at": "2023-10-10T16:17:06", - }, - { - "name": "question-4", - "title": "Question-4", - "description": None, - "required": False, - "settings": { - "type": "multi_label_selection", - "options": [ - {"value": "label-1", "text": "label-1", "description": None}, - {"value": "label-2", "text": "label-2", "description": None}, - {"value": "label-3", "text": "label-3", "description": None}, - ], - "visible_options": 3, - "options_order": LabelsOrder.natural, - }, - "visible_labels": 20, - "id": "0fbcf59a-eef9-48d0-b50c-011b22a1b611", - "inserted_at": "2023-10-10T16:17:07", - "updated_at": "2023-10-10T16:17:07", - }, - ], - "guidelines": "These are the guidelines", - }, - ), - ( - "1.11.0", - { - "fields": [ - { - "name": "field-1", - "title": "Field-1", - "required": True, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - }, - { - "name": "field-2", - "title": "Field-2", - "required": False, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - }, - ], - "questions": [ - { - "name": "question-1", - "title": "Question-1", - "description": None, - "required": True, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - }, - { - "name": "question-2", - "title": "Question-2", - "description": None, - "required": False, - "settings": { - "type": "rating", - "options": [{"value": 1}, {"value": 2}, {"value": 3}, {"value": 4}, {"value": 5}], - }, - "values": [1, 2, 3, 4, 5], - }, - { - "name": "question-3", - "title": "Question-3", - "description": None, - "required": False, - "settings": { - "type": "label_selection", - "options": [ - {"value": "label-1", "text": "label-1"}, - {"value": "label-2", "text": "label-2"}, - {"value": "label-3", "text": "label-3"}, - ], - "visible_options": 3, - }, - "labels": ["label-1", "label-2", "label-3"], - "visible_labels": 3, - }, - { - "name": "question-4", - 
"title": "Question-4", - "description": None, - "required": False, - "settings": { - "type": "multi_label_selection", - "options": [ - {"value": "label-1", "text": "label-1"}, - {"value": "label-2", "text": "label-2"}, - {"value": "label-3", "text": "label-3"}, - ], - "visible_options": 3, - "options_order": LabelsOrder.natural, - }, - "labels": ["label-1", "label-2", "label-3"], - "visible_labels": 3, - }, - ], - "guidelines": "These are the guidelines", - }, - ), - ( - "1.12.0", - { - "fields": [ - { - "name": "field-1", - "title": "Field-1", - "required": True, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - }, - { - "name": "field-2", - "title": "Field-2", - "required": False, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - }, - ], - "questions": [ - { - "name": "question-1", - "title": "Question-1", - "description": None, - "required": True, - "settings": {"type": "text", "use_markdown": False, "use_table": False}, - "use_markdown": False, - }, - { - "name": "question-2", - "title": "Question-2", - "description": None, - "required": False, - "settings": { - "type": "rating", - "options": [{"value": 1}, {"value": 2}, {"value": 3}, {"value": 4}, {"value": 5}], - }, - "values": [1, 2, 3, 4, 5], - }, - { - "name": "question-3", - "title": "Question-3", - "description": None, - "required": False, - "settings": { - "type": "label_selection", - "options": [ - {"value": "label-1", "text": "label-1"}, - {"value": "label-2", "text": "label-2"}, - {"value": "label-3", "text": "label-3"}, - ], - "visible_options": 3, - }, - "labels": ["label-1", "label-2", "label-3"], - "visible_labels": 3, - }, - { - "name": "question-4", - "title": "Question-4", - "description": None, - "required": False, - "settings": { - "type": "multi_label_selection", - "options": [ - {"value": "label-1", "text": "label-1"}, - {"value": "label-2", "text": "label-2"}, - {"value": "label-3", 
"text": "label-3"}, - ], - "visible_options": 3, - "options_order": LabelsOrder.natural, - }, - "labels": ["label-1", "label-2", "label-3"], - "visible_labels": 3, - "labels_order": LabelsOrder.natural, - }, - { - "name": "question-5", - "title": "Question-5", - "description": None, - "required": False, - "settings": { - "type": "ranking", - "options": [ - {"value": "label-1", "text": "label-1"}, - {"value": "label-2", "text": "label-2"}, - {"value": "label-3", "text": "label-3"}, - ], - }, - "values": ["label-1", "label-2", "label-3"], - }, - ], - "guidelines": "These are the guidelines", - }, - ), - ), -) -def test_dataset_config_backwards_compatibility_argilla_cfg( - argilla_version: str, outdated_config: Dict[str, Any] -) -> None: - print(f"Loading `argilla.cfg` dumped using `push_to_huggingface` from argilla=={argilla_version}") - config = DeprecatedDatasetConfig.from_json(json.dumps(outdated_config)) - assert isinstance(config, DeprecatedDatasetConfig) - - for field in config.fields: - assert isinstance(field, FieldSchema) - matching_field = next( - (outdated_field for outdated_field in outdated_config["fields"] if outdated_field["name"] == field.name), - None, - ) - assert matching_field is not None - assert field.title == matching_field["title"] - assert field.required == matching_field["required"] - if "settings" in matching_field: - assert field.server_settings == matching_field["settings"] - - for question in config.questions: - assert isinstance(question, QuestionSchema) - matching_question = next( - ( - outdated_question - for outdated_question in outdated_config["questions"] - if outdated_question["name"] == question.name - ), - None, - ) - assert matching_question is not None - assert question.title == matching_question["title"] - assert question.description == matching_question["description"] - assert question.required == matching_question["required"] - if "settings" in matching_question: - if matching_question["settings"]["type"] in 
["label_selection", "multi_label_selection"]: - _ = [option.pop("description", None) for option in matching_question["settings"]["options"]] - assert question.server_settings == matching_question["settings"] - - assert config.guidelines == outdated_config["guidelines"] - - -# Same thing but testing the remaining versions and using YAML as -@pytest.mark.parametrize( - "argilla_version, outdated_config", - ( - ( - "1.13.0", - """ - fields: - - id: !!python/object:uuid.UUID - int: 318598997309827170175814937554257429138 - name: field-1 - required: true - settings: - type: text - use_markdown: false - use_table: false - title: Field-1 - type: text - use_markdown: false - - id: !!python/object:uuid.UUID - int: 69686603559390055136114715170048137456 - name: field-2 - required: false - settings: - type: text - use_markdown: false - use_table: false - title: Field-2 - type: text - use_markdown: false - guidelines: These are the guidelines - questions: - - description: null - id: !!python/object:uuid.UUID - int: 50048965276074224092389052083005517643 - name: question-1 - required: true - settings: - type: text - use_markdown: false - use_table: false - title: Question-1 - type: text - use_markdown: false - - description: null - id: !!python/object:uuid.UUID - int: 41875146250353770121043770832765946446 - name: question-2 - required: false - settings: - options: - - value: 1 - - value: 2 - - value: 3 - - value: 4 - - value: 5 - type: rating - title: Question-2 - type: rating - values: - - 1 - - 2 - - 3 - - 4 - - 5 - - description: null - id: !!python/object:uuid.UUID - int: 157923404852454712001576490052121022141 - labels: - - label-1 - - label-2 - - label-3 - name: question-3 - required: false - settings: - options: - - text: label-1 - value: label-1 - - text: label-2 - value: label-2 - - text: label-3 - value: label-3 - type: label_selection - visible_options: 3 - title: Question-3 - type: label_selection - visible_labels: 3 - - description: null - id: 
!!python/object:uuid.UUID - int: 250432660168731216394809741082680978815 - labels: - - label-1 - - label-2 - - label-3 - name: question-4 - required: false - settings: - options: - - text: label-1 - value: label-1 - - text: label-2 - value: label-2 - - text: label-3 - value: label-3 - type: multi_label_selection - visible_options: 3 - options_order: natural - title: Question-4 - type: multi_label_selection - visible_labels: 3 - labels_order: natural - - description: null - id: !!python/object:uuid.UUID - int: 251163320782812347764238417960223431273 - name: question-5 - required: false - settings: - options: - - text: label-1 - value: label-1 - - text: label-2 - value: label-2 - - text: label-3 - value: label-3 - type: ranking - title: Question-5 - type: ranking - values: - - label-1 - - label-2 - - label-3 - """, - ), - ( - "1.14.0,1.15.0,1.16.0", - """ - fields: - - name: field-1 - required: true - settings: - type: text - use_markdown: false - use_table: false - title: Field-1 - type: text - use_markdown: false - - name: field-2 - required: false - settings: - type: text - use_markdown: false - use_table: false - title: Field-2 - type: text - use_markdown: false - guidelines: These are the guidelines - questions: - - description: null - name: question-1 - required: true - settings: - type: text - use_markdown: false - use_table: false - title: Question-1 - type: text - use_markdown: false - - description: null - name: question-2 - required: false - settings: - options: - - value: 1 - - value: 2 - - value: 3 - - value: 4 - - value: 5 - type: rating - title: Question-2 - type: rating - values: - - 1 - - 2 - - 3 - - 4 - - 5 - - description: null - labels: - - label-1 - - label-2 - - label-3 - name: question-3 - required: false - settings: - options: - - text: label-1 - value: label-1 - - text: label-2 - value: label-2 - - text: label-3 - value: label-3 - type: label_selection - visible_options: 3 - title: Question-3 - type: label_selection - visible_labels: 3 - - 
description: null - labels: - - label-1 - - label-2 - - label-3 - name: question-4 - required: false - settings: - options: - - text: label-1 - value: label-1 - - text: label-2 - value: label-2 - - text: label-3 - value: label-3 - type: multi_label_selection - visible_options: 3 - options_order: natural - title: Question-4 - type: multi_label_selection - visible_labels: 3 - labels_order: natural - - description: null - name: question-5 - required: false - settings: - options: - - text: label-1 - value: label-1 - - text: label-2 - value: label-2 - - text: label-3 - value: label-3 - type: ranking - title: Question-5 - type: ranking - values: - - label-1 - - label-2 - - label-3 - """, - ), - ), -) -def test_dataset_config_backwards_compatibility_argilla_yaml(argilla_version: str, outdated_config: str) -> None: - print(f"Loading `argilla.yaml` dumped using `push_to_huggingface` from argilla=={argilla_version}") - config = DatasetConfig.from_yaml(outdated_config) - assert isinstance(config, DatasetConfig) - - outdated_config_as_dict = load( - re.sub(r"(\n\s*|)id: !!python/object:uuid\.UUID\s+int: \d+", "", outdated_config), Loader=SafeLoader - ) - assert isinstance(outdated_config_as_dict, dict) - - for field in config.fields: - assert isinstance(field, FieldSchema) - matching_field = next( - ( - outdated_field - for outdated_field in outdated_config_as_dict["fields"] - if outdated_field["name"] == field.name - ), - None, - ) - assert matching_field is not None - assert field.title == matching_field["title"] - assert field.required == matching_field["required"] - if "settings" in matching_field: - assert field.server_settings == matching_field["settings"] - - for question in config.questions: - assert isinstance(question, QuestionSchema) - matching_question = next( - ( - outdated_question - for outdated_question in outdated_config_as_dict["questions"] - if outdated_question["name"] == question.name - ), - None, - ) - assert matching_question is not None - assert 
question.title == matching_question["title"] - assert question.description == matching_question["description"] - assert question.required == matching_question["required"] - if "settings" in matching_question: - if matching_question["settings"]["type"] in ["label_selection", "multi_label_selection"]: - _ = [option.pop("description", None) for option in matching_question["settings"]["options"]] - assert question.server_settings == matching_question["settings"] - - assert config.guidelines == outdated_config_as_dict["guidelines"] diff --git a/argilla-v1/tests/unit/feedback/training/__init__.py b/argilla-v1/tests/unit/feedback/training/__init__.py deleted file mode 100644 index 55be41799..000000000 --- a/argilla-v1/tests/unit/feedback/training/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/argilla-v1/tests/unit/feedback/training/test_schemas.py b/argilla-v1/tests/unit/feedback/training/test_schemas.py deleted file mode 100644 index b3a730a11..000000000 --- a/argilla-v1/tests/unit/feedback/training/test_schemas.py +++ /dev/null @@ -1,356 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pandas as pd -import pytest -import spacy -from argilla_v1 import ( - LabelQuestion, - MultiLabelQuestion, - RankingQuestion, - RatingQuestion, - TextField, -) -from argilla_v1.client.feedback.training.schemas.base import TrainingTask -from argilla_v1.client.feedback.unification import ( - LabelQuestionUnification, - MultiLabelQuestionUnification, - RankingQuestionUnification, - RatingQuestionUnification, -) -from argilla_v1.client.models import Framework -from datasets import Dataset, DatasetDict -from spacy.tokens import DocBin - - -@pytest.mark.parametrize( - "framework, label, train_size, seed, expected", - [ - ( - Framework("spacy"), - RatingQuestionUnification, - 0.5, - None, - (DocBin, DocBin), - ), - ( - Framework("spacy"), - RankingQuestionUnification, - 0.5, - None, - (DocBin, DocBin), - ), - ( - Framework("spacy"), - LabelQuestionUnification, - 0.5, - None, - (DocBin, DocBin), - ), - ( - Framework("spacy"), - MultiLabelQuestionUnification, - 0.5, - None, - (DocBin, DocBin), - ), - ( - Framework("spacy"), - RatingQuestionUnification, - 1, - 42, - DocBin, - ), - ( - Framework("spacy"), - RankingQuestionUnification, - 1, - 42, - DocBin, - ), - (Framework("spacy"), LabelQuestionUnification, 1, 42, DocBin), - ( - Framework("spacy"), - MultiLabelQuestionUnification, - 1, - 42, - DocBin, - ), - ( - Framework("spacy-transformers"), - RatingQuestionUnification, - 0.5, - None, - (DocBin, DocBin), - ), - ( - Framework("spacy-transformers"), - RankingQuestionUnification, - 0.5, - None, - (DocBin, DocBin), - ), - ( - Framework("spacy-transformers"), 
- LabelQuestionUnification, - 0.5, - None, - (DocBin, DocBin), - ), - ( - Framework("spacy-transformers"), - MultiLabelQuestionUnification, - 0.5, - None, - (DocBin, DocBin), - ), - ( - Framework("spacy-transformers"), - RatingQuestionUnification, - 1, - 42, - DocBin, - ), - ( - Framework("spacy-transformers"), - RankingQuestionUnification, - 1, - 42, - DocBin, - ), - ( - Framework("spacy-transformers"), - LabelQuestionUnification, - 1, - 42, - DocBin, - ), - ( - Framework("spacy-transformers"), - MultiLabelQuestionUnification, - 1, - 42, - DocBin, - ), - ( - Framework("openai"), - RatingQuestionUnification, - 0.5, - None, - (list, list), - ), - ( - Framework("openai"), - RankingQuestionUnification, - 0.5, - None, - (list, list), - ), - ( - Framework("openai"), - LabelQuestionUnification, - 0.5, - None, - (list, list), - ), - ( - Framework("openai"), - MultiLabelQuestionUnification, - 0.5, - None, - (list, list), - ), - ( - Framework("openai"), - RatingQuestionUnification, - 1, - 42, - list, - ), - ( - Framework("openai"), - RankingQuestionUnification, - 1, - 42, - list, - ), - (Framework("openai"), LabelQuestionUnification, 1, 42, list), - ( - Framework("openai"), - MultiLabelQuestionUnification, - 1, - 42, - list, - ), - ( - Framework("transformers"), - RatingQuestionUnification, - 0.5, - None, - DatasetDict, - ), - ( - Framework("transformers"), - RankingQuestionUnification, - 0.5, - None, - DatasetDict, - ), - ( - Framework("transformers"), - LabelQuestionUnification, - 0.5, - None, - DatasetDict, - ), - ( - Framework("transformers"), - MultiLabelQuestionUnification, - 0.5, - None, - DatasetDict, - ), - ( - Framework("transformers"), - RatingQuestionUnification, - 1, - 42, - Dataset, - ), - ( - Framework("transformers"), - RankingQuestionUnification, - 1, - 42, - Dataset, - ), - ( - Framework("transformers"), - LabelQuestionUnification, - 1, - 42, - Dataset, - ), - ( - Framework("transformers"), - MultiLabelQuestionUnification, - 1, - 42, - Dataset, - ), - ( - 
Framework("spark-nlp"), - RatingQuestionUnification, - 0.5, - None, - (pd.DataFrame, pd.DataFrame), - ), - ( - Framework("spark-nlp"), - RankingQuestionUnification, - 0.5, - None, - (pd.DataFrame, pd.DataFrame), - ), - ( - Framework("spark-nlp"), - LabelQuestionUnification, - 0.5, - None, - (pd.DataFrame, pd.DataFrame), - ), - ( - Framework("spark-nlp"), - MultiLabelQuestionUnification, - 0.5, - None, - (pd.DataFrame, pd.DataFrame), - ), - ( - Framework("spark-nlp"), - RatingQuestionUnification, - 1, - 42, - pd.DataFrame, - ), - ( - Framework("spark-nlp"), - RankingQuestionUnification, - 1, - 42, - pd.DataFrame, - ), - ( - Framework("spark-nlp"), - LabelQuestionUnification, - 1, - 42, - pd.DataFrame, - ), - ( - Framework("spark-nlp"), - MultiLabelQuestionUnification, - 1, - 42, - pd.DataFrame, - ), - ], -) -def test_task_for_text_classification( - framework, - label, - train_size, - seed, - expected, - rating_question_payload, - ranking_question_payload, - label_question_payload, -): - if label == RatingQuestionUnification: - label = RatingQuestionUnification(question=RatingQuestion(**rating_question_payload)) - elif label == RankingQuestionUnification: - label = RankingQuestionUnification(question=RankingQuestion(**ranking_question_payload)) - elif label == LabelQuestionUnification: - label = LabelQuestionUnification(question=LabelQuestion(**label_question_payload)) - elif label == MultiLabelQuestionUnification: - label = MultiLabelQuestionUnification(question=MultiLabelQuestion(**label_question_payload)) - data = [{"text": "This is a text", "label": "1"}, {"text": "This is a text", "label": "2"}] - field = TextField(name="text") - task = TrainingTask.for_text_classification(text=field, label=label) - if framework == Framework.SPACY or framework == Framework.SPACY_TRANSFORMERS: - data = task._prepare_for_training_with_spacy( - data=data, train_size=train_size, seed=seed, lang=spacy.blank("en") - ) - elif framework == Framework.OPENAI: - data = 
task._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed) - elif framework == Framework.TRANSFORMERS: - data = task._prepare_for_training_with_transformers( - data=data, train_size=train_size, seed=seed, framework=Framework.TRANSFORMERS - ) - elif framework == Framework.SPARK_NLP: - data = task._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed) - else: - raise ValueError(f"Framework {framework} not supported") - if isinstance(data, tuple): - for d, e in zip(data, expected): - assert isinstance(d, e) - else: - assert isinstance(data, expected) - - -def test_training_task_repr(label_question_payload): - field = TextField(name="text") - label = LabelQuestion(**label_question_payload) - task_mapping = TrainingTask.for_text_classification(text=field, label=label) - assert isinstance(repr(task_mapping), str) diff --git a/argilla-v1/tests/unit/feedback/utils/test_assignment.py b/argilla-v1/tests/unit/feedback/utils/test_assignment.py deleted file mode 100644 index c6400dd62..000000000 --- a/argilla-v1/tests/unit/feedback/utils/test_assignment.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from unittest.mock import Mock, patch - -import pytest -from argilla_v1.client.feedback.utils.assignment import ( - assign_records, - assign_records_to_groups, - assign_records_to_individuals, - assign_workspaces, - check_user, - check_workspace, -) -from argilla_v1.client.users import User -from argilla_v1.client.workspaces import Workspace - - -@pytest.fixture -def mock_user(): - user = Mock(spec=User) - user.username = "test_user" - return user - - -@pytest.fixture -def mock_workspace(): - return Mock(spec=Workspace) - - -@pytest.fixture -def mock_check_user(): - def _mock(user_name): - user = Mock(spec=User) - user.username = user_name - user.id = f"{user_name}_id" - return user - - return _mock - - -@pytest.fixture -def mock_workspace_factory(): - def _factory(*args, **kwargs): - mock = Mock(spec=Workspace) - mock.users = [] - - def create_mock_user(user_id): - user_mock = Mock() - user_mock.id = user_id - return user_mock - - def add_user(user_id): - # Check if a user with this ID already exists in the list - if not any(user.id == user_id for user in mock.users): - mock_user = create_mock_user(user_id) - mock.users.append(mock_user) - - mock.add_user.side_effect = add_user - - return mock - - return _factory - - -@pytest.mark.parametrize( - "input, exists, warning, is_user_obj", - [("existing_user", True, False, False), ("new_user", False, True, False), (mock_user, True, False, True)], -) -@patch("argilla_v1.client.users.User.create") -@patch("argilla_v1.client.users.User.from_name") -def test_check_user(mock_from_name, mock_create, input, exists, warning, is_user_obj, mock_user): - if is_user_obj: - user_input = mock_user - else: - user_input = input - if exists: - mock_from_name.return_value = mock_user - else: - mock_from_name.side_effect = ValueError - mock_create.return_value = mock_user - - result = check_user(user_input) - - assert result is mock_user - - if not exists and not is_user_obj: - mock_create.assert_called_with(username=input, 
first_name=input, password="12345678", role="annotator") - elif exists and not is_user_obj: - mock_from_name.assert_called_with(input) - - -@pytest.mark.parametrize("workspace_name, workspace_exists", [("existing_workspace", True), ("new_workspace", False)]) -@patch("argilla_v1.client.workspaces.Workspace.from_name") -@patch("argilla_v1.client.workspaces.Workspace.create") -def test_check_workspace(mock_create, mock_from_name, mock_workspace, workspace_name, workspace_exists): - if workspace_exists: - mock_from_name.return_value = mock_workspace - mock_from_name.side_effect = None - else: - mock_from_name.side_effect = ValueError("Workspace does not exist.") - mock_create.return_value = mock_workspace - - result = check_workspace(workspace_name) - - assert result is mock_workspace - if workspace_exists: - mock_from_name.assert_called_with(workspace_name) - else: - mock_create.assert_called_with(workspace_name) - - -@pytest.mark.parametrize( - "overlap, shuffle, expected_error, expected_result", - [ - (1, True, False, None), - (1, False, False, None), - ( - 1, - False, - None, - { - "group1": {"user1": ["record1", "record4"], "user2": ["record1", "record4"]}, - "group2": {"user3": ["record2", "record5"], "user4": ["record2", "record5"]}, - "group3": {"user5": ["record3", "record6"]}, - }, - ), - ( - 2, - False, - None, - { - "group1": { - "user1": ["record1", "record3", "record4", "record6"], - "user2": ["record1", "record3", "record4", "record6"], - }, - "group2": { - "user3": ["record1", "record2", "record4", "record5"], - "user4": ["record1", "record2", "record4", "record5"], - }, - "group3": {"user5": ["record2", "record3", "record5", "record6"]}, - }, - ), - (5, False, ValueError, None), - (-1, False, ValueError, None), - ], -) -@patch("argilla_v1.client.feedback.utils.assignment.random.shuffle") -def test_assign_records_to_groups(mock_shuffle, overlap, shuffle, expected_error, expected_result, mock_check_user): - mock_groups = {"group1": ["user1", "user2"], 
"group2": ["user3", "user4"], "group3": ["user5"]} - mock_records = ["record1", "record2", "record3", "record4", "record5", "record6"] - - with patch("argilla_v1.client.feedback.utils.assignment.check_user", side_effect=mock_check_user): - if expected_error: - with pytest.raises(expected_error): - assign_records_to_groups(mock_groups, mock_records, overlap, shuffle) - else: - result = assign_records_to_groups(mock_groups, mock_records, overlap, shuffle) - if expected_result is not None: - assert result == expected_result - - if shuffle: - mock_shuffle.assert_called_once_with(mock_records) - else: - mock_shuffle.assert_not_called() - - -@pytest.mark.parametrize( - "overlap, shuffle, expected_error, expected_result", - [ - (1, True, False, None), - (1, False, False, None), - (1, False, None, {"user1": ["record1", "record4"], "user2": ["record2", "record5"], "user3": ["record3"]}), - ( - 2, - False, - None, - { - "user1": ["record1", "record3", "record4"], - "user2": ["record1", "record2", "record4", "record5"], - "user3": ["record2", "record3", "record5"], - }, - ), - (5, False, ValueError, None), - (-1, False, ValueError, None), - ], -) -@patch("argilla_v1.client.feedback.utils.assignment.random.shuffle") -def test_assign_records_to_individuals( - mock_shuffle, overlap, shuffle, expected_error, expected_result, mock_check_user -): - mock_users = [f"user{i}" for i in range(1, 4)] - mock_records = ["record1", "record2", "record3", "record4", "record5"] - - with patch("argilla_v1.client.feedback.utils.assignment.check_user", side_effect=mock_check_user): - if expected_error: - with pytest.raises(expected_error): - assign_records_to_individuals(mock_users, mock_records, overlap, shuffle) - else: - result = assign_records_to_individuals(mock_users, mock_records, overlap, shuffle) - if expected_result is not None: - assert result == expected_result - - if shuffle: - mock_shuffle.assert_called_once_with(mock_records) - else: - mock_shuffle.assert_not_called() - - 
-@pytest.mark.parametrize( - "input, overlap, shuffle, expected_result", - [ - ( - {"group1": ["user1", "user2"], "group2": ["user3"]}, - 1, - True, - {"group1": {"user1": ["record1", "record2"], "user2": ["record3"]}, "group2": {"user3": ["record4"]}}, - ), - ( - [f"user{i}" for i in range(1, 4)], - 0, - False, - {"user1": ["record1", "record4"], "user2": ["record2"], "user3": ["record3"]}, - ), - ], -) -def test_assign_records(input, overlap, shuffle, expected_result): - mock_records = ["record1", "record2", "record3", "record4", "record5"] - - def mock_assign_records_to_groups(users, records, overlap, shuffle): - assert users == input - assert records == mock_records - assert overlap == overlap - assert shuffle == shuffle - return expected_result - - def mock_assign_records_to_individuals(users, records, overlap, shuffle): - assert users == input - assert records == mock_records - assert overlap == overlap - assert shuffle == shuffle - return expected_result - - with patch( - "argilla_v1.client.feedback.utils.assignment.assign_records_to_groups", - side_effect=mock_assign_records_to_groups, - ): - with patch( - "argilla_v1.client.feedback.utils.assignment.assign_records_to_individuals", - side_effect=mock_assign_records_to_individuals, - ): - result = assign_records(input, mock_records, overlap, shuffle) - assert result == expected_result - - -@pytest.mark.parametrize( - "mock_assignments, assignment_type, expected_result", - [ - ( - { - "group1": {"user1": ["record1", "record3", "record5"], "user2": ["record1", "record3", "record5"]}, - "group2": {"user3": ["record2", "record4"]}, - }, - "group", - {"group1": ["user1", "user2"], "group2": ["user3"]}, - ), - ( - { - "group1": {"user1": ["record1", "record3", "record5"], "user2": ["record1", "record3", "record5"]}, - "group2": {"user3": ["record2", "record4"]}, - }, - "group_personal", - {"user1": ["user1"], "user2": ["user2"], "user3": ["user3"]}, - ), - ( - {"user1": ["record1", "record4"], "user2": ["record2", 
"record5"], "user3": ["record3"]}, - "individual", - {"user1": ["user1"], "user2": ["user2"], "user3": ["user3"]}, - ), - ], -) -def test_assign_workspaces(mock_check_user, mock_workspace_factory, mock_assignments, assignment_type, expected_result): - with patch("argilla_v1.client.feedback.utils.assignment.check_user", side_effect=mock_check_user): - with patch( - "argilla_v1.client.feedback.utils.assignment.User.from_id", - side_effect=lambda user_id: Mock(username=user_id.split("_")[0]), - ): - with patch( - "argilla_v1.client.feedback.utils.assignment.check_workspace", side_effect=mock_workspace_factory - ): - result = assign_workspaces(mock_assignments, assignment_type) - assert result == expected_result diff --git a/argilla-v1/tests/unit/feedback/utils/test_html_utils.py b/argilla-v1/tests/unit/feedback/utils/test_html_utils.py deleted file mode 100644 index 08866d76d..000000000 --- a/argilla-v1/tests/unit/feedback/utils/test_html_utils.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import base64 -from unittest import mock - -import pytest -from argilla_v1.client.feedback.utils import ( - audio_to_html, - create_token_highlights, - get_file_data, - image_to_html, - media_to_html, - pdf_to_html, - video_to_html, -) - - -@pytest.mark.parametrize( - "file_source, file_type, media_type, file_exists, file_size, expected_output, expected_exception", - [ - ("path/to/image.jpg", "jpg", "image", True, 1000, (b"sample_data", "jpg"), None), - (b"image_data", "jpg", "image", None, None, (b"image_data", "jpg"), None), - ("path/to/video.mp4", "mp4", "video", True, 2000, (b"sample_data", "mp4"), None), - (b"video_data", None, "video", None, None, None, ValueError), - ("path/to/nonexistent.jpg", "jpg", "image", False, 0, None, FileNotFoundError), - ("path/to/large_file.mp4", "mp4", "video", True, 6_000_000, None, ValueError), - ("path/to/wrong_extension.txt", "jpg", "image", True, 1000, None, ValueError), - ], -) -@mock.patch("pathlib.Path.exists") -@mock.patch("pathlib.Path.stat") -@mock.patch("pathlib.Path.read_bytes") -def test_get_file_data( - mock_read_bytes, - mock_stat, - mock_exists, - file_source, - file_type, - media_type, - file_exists, - file_size, - expected_output, - expected_exception, -): - if isinstance(file_source, str): - mock_exists.return_value = file_exists - mock_stat.return_value = mock.Mock(st_size=file_size) - if expected_exception == ValueError and file_size > 5_000_000: - mock_read_bytes.return_value = b"a" * file_size - else: - mock_read_bytes.return_value = b"sample_data" - - if expected_exception: - with pytest.raises(expected_exception): - get_file_data(file_source, file_type, media_type) - else: - assert get_file_data(file_source, file_type, media_type) == expected_output - - -@pytest.mark.parametrize( - "media_type, file_source, file_type, width, height, autoplay, loop, is_valid_dim, file_data, expected_output, expected_exception", - [ - ( - "image", - "path/to/image.jpg", - "jpeg", - "300px", - "200px", - False, - 
False, - True, - b"image_data", - '', - None, - ), - ( - "video", - b"video_data", - "mp4", - None, - None, - True, - True, - True, - b"video_data", - "", - None, - ), - ("audio", "path/to/audio.mp3", "mp3", "100%", "invalid", False, False, False, b"audio_data", None, ValueError), - ("document", "path/to/doc.txt", "txt", None, None, False, False, True, b"doc_data", None, ValueError), - ], -) -@mock.patch("argilla_v1.client.feedback.utils.html_utils.get_file_data") -@mock.patch("argilla_v1.client.feedback.utils.html_utils.is_valid_dimension") -def test_media_to_html( - mock_is_valid_dimension, - mock_get_file_data, - media_type, - file_source, - file_type, - width, - height, - autoplay, - loop, - is_valid_dim, - file_data, - expected_output, - expected_exception, -): - mock_is_valid_dimension.return_value = is_valid_dim - mock_get_file_data.return_value = (file_data, file_type) - - if expected_exception: - with pytest.raises(expected_exception): - media_to_html(media_type, file_source, file_type, width, height, autoplay, loop) - else: - assert media_to_html(media_type, file_source, file_type, width, height, autoplay, loop) == expected_output - - -@pytest.mark.parametrize( - "func, input_file, expected", - [ - (video_to_html, "test.mp4", ""), - ( - audio_to_html, - "test.mp3", - "", - ), - (image_to_html, "test.png", ''), - ( - pdf_to_html, - "test.pdf", - '

Unable to display PDF.

', - ), - ( - pdf_to_html, - "https://my_pdf.pdf", - '', - ), - ], -) -def test_wrappers(func, input_file, expected, tmp_path): - if "https://" in input_file: - encoded_data = input_file - assert func(encoded_data) == expected.format(encoded_data) - else: - temp_file = tmp_path / input_file - temp_file.write_bytes(b"dummy_data") - encoded_data = base64.b64encode(b"dummy_data").decode("utf-8") - expected_html = expected.format(encoded_data) - assert func(str(temp_file)) == expected_html - - -@pytest.mark.parametrize( - "tokens,weights,c_map,expected_error,expected_substr", - [ - (["token1", "token2"], [0.2], None, ValueError, None), - (["token1", "token2"], [0.2, 0.8], 42, TypeError, None), - (["token1", "token2"], [0.2, 0.8], lambda x: (1, 0, 0, 1), None, "#ff0000"), - (["token1", "token2"], [0, 0], None, None, "

token1 token2

"), - ([], [], None, ValueError, None), - (["token1"], [0.5], None, None, "token1"), - (["token1", "token2"], [0.5, 0.8], "viridis", None, ' Path: - return Path(__file__).parent / "resources" diff --git a/argilla-v1/tests/unit/labeling/text_classification/resources/weak-supervision-guide-matrix.npy b/argilla-v1/tests/unit/labeling/text_classification/resources/weak-supervision-guide-matrix.npy deleted file mode 100644 index 9cba232eb..000000000 Binary files a/argilla-v1/tests/unit/labeling/text_classification/resources/weak-supervision-guide-matrix.npy and /dev/null differ diff --git a/argilla-v1/tests/unit/labeling/text_classification/test_label_models.py b/argilla-v1/tests/unit/labeling/text_classification/test_label_models.py deleted file mode 100644 index 93e8511ee..000000000 --- a/argilla-v1/tests/unit/labeling/text_classification/test_label_models.py +++ /dev/null @@ -1,788 +0,0 @@ -# coding=utf-8 -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import sys - -import numpy as np -import pytest -from argilla_v1.client.models import TextClassificationRecord -from argilla_v1.labeling.text_classification import ( - FlyingSquid, - Snorkel, - WeakLabels, - WeakMultiLabels, -) -from argilla_v1.labeling.text_classification.label_models import ( - LabelModel, - MajorityVoter, - MissingAnnotationError, - NotFittedError, - TieBreakPolicy, - TooFewRulesError, -) - - -@pytest.fixture -def weak_labels(monkeypatch): - def mock_load(*args, **kwargs): - return [ - TextClassificationRecord(text="test", annotation="negative"), - TextClassificationRecord(text="test", annotation="positive"), - TextClassificationRecord(text="test"), - TextClassificationRecord(text="test", annotation="neutral"), - ] - - monkeypatch.setattr("argilla_v1.labeling.text_classification.weak_labels.load", mock_load) - - def mock_apply(self, *args, **kwargs): - weak_label_matrix = np.array( - [[0, 1, -1], [2, 0, -1], [-1, -1, -1], [0, 2, 2]], - dtype=np.short, - ) - annotation_array = np.array([0, 1, -1, 2], dtype=np.short) - label2int = {None: -1, "negative": 0, "positive": 1, "neutral": 2} - return weak_label_matrix, annotation_array, label2int - - monkeypatch.setattr(WeakLabels, "_apply_rules", mock_apply) - - return WeakLabels(rules=[lambda: None] * 3, dataset="mock") - - -@pytest.fixture -def weak_labels_from_guide(monkeypatch, resources): - matrix_and_annotation = np.load(str(resources / "weak-supervision-guide-matrix.npy")) - matrix, annotation = matrix_and_annotation[:, :-1], matrix_and_annotation[:, -1] - - def mock_load(*args, **kwargs): - return [TextClassificationRecord(text="mock", id=i) for i in range(len(matrix))] - - monkeypatch.setattr("argilla_v1.labeling.text_classification.weak_labels.load", mock_load) - - def mock_apply(self, *args, **kwargs): - return matrix, annotation, {None: -1, "SPAM": 0, "HAM": 1} - - monkeypatch.setattr(WeakLabels, "_apply_rules", mock_apply) - - return WeakLabels(rules=[lambda x: "mock"] * matrix.shape[1], 
dataset="mock") - - -@pytest.fixture -def weak_multi_labels(monkeypatch): - def mock_load(*args, **kwargs): - return [ - TextClassificationRecord(text="test", multi_label=True, annotation=["scared"]), - TextClassificationRecord(text="test", multi_label=True, annotation=["sad", "scared"]), - TextClassificationRecord(text="test", multi_label=True, annotation=[]), - TextClassificationRecord(text="test", multi_label=True), - ] - - monkeypatch.setattr("argilla_v1.labeling.text_classification.weak_labels.load", mock_load) - - def mock_apply(self, *args, **kwargs): - weak_label_matrix = np.array( - [ - [[0, 0, 1], [-1, -1, -1]], - [[0, 1, 1], [1, 0, 1]], - [[-1, -1, -1], [-1, -1, -1]], - [[0, 0, 0], [0, 0, 0]], - ], - dtype=np.byte, - ) - annotation_array = np.array([[0, 0, 1], [1, 0, 1], [0, 0, 0], [-1, -1, -1]], dtype=np.byte) - labels = ["sad", "happy", "scared"] - return weak_label_matrix, annotation_array, labels - - monkeypatch.setattr(WeakMultiLabels, "_apply_rules", mock_apply) - - return WeakMultiLabels(rules=[lambda: None] * 2, dataset="mock") - - -def test_tie_break_policy_enum(): - with pytest.raises(ValueError, match="mock is not a valid TieBreakPolicy"): - TieBreakPolicy("mock") - - -class TestLabelModel: - def test_weak_label_property(self): - weak_labels = object() - label_model = LabelModel(weak_labels) - - assert label_model.weak_labels is weak_labels - - def test_abstract_methods(self): - label_model = LabelModel(None) - with pytest.raises(NotImplementedError): - label_model.fit() - with pytest.raises(NotImplementedError): - label_model.score() - with pytest.raises(NotImplementedError): - label_model.predict() - - -class TestMajorityVoter: - def test_no_need_to_fit_error(self): - mj = MajorityVoter(None) - with pytest.raises(NotImplementedError, match="No need to call"): - mj.fit() - - @pytest.mark.parametrize( - "wls, include_annotated_records, expected", - [ - ("weak_labels", True, 4), - ("weak_labels", False, 1), - ("weak_multi_labels", True, 4), - 
("weak_multi_labels", False, 1), - ], - ) - def test_predict(self, monkeypatch, request, wls, include_annotated_records, expected): - def compute_probs(self, wl_matrix, **kwargs): - assert len(wl_matrix) == expected - compute_probs.called = None - - def make_records(self, probabilities, records, **kwargs): - assert probabilities is None - return records - - single_or_multi = "multi" if wls == "weak_multi_labels" else "single" - monkeypatch.setattr(MajorityVoter, f"_compute_{single_or_multi}_label_probs", compute_probs) - monkeypatch.setattr(MajorityVoter, f"_make_{single_or_multi}_label_records", make_records) - - weak_labels = request.getfixturevalue(wls) - mj = MajorityVoter(weak_labels) - - assert len(mj.predict(include_annotated_records=include_annotated_records)) == expected - assert hasattr(compute_probs, "called") - - def test_compute_single_label_probs(self, weak_labels): - mj = MajorityVoter(weak_labels) - probs = mj._compute_single_label_probs(weak_labels.matrix()) - - expected = np.array( - [ - [0.5, 0.5, 0.0], - [0.5, 0, 0.5], - [1.0 / 3, 1.0 / 3, 1.0 / 3], - [1.0 / 3, 0.0, 2.0 / 3], - ] - ) - assert np.allclose(probs, expected) - - @pytest.mark.parametrize( - "include_abstentions,tie_break_policy,expected", - [ - (True, TieBreakPolicy.ABSTAIN, 4), - (False, TieBreakPolicy.ABSTAIN, 1), - (True, TieBreakPolicy.RANDOM, 4), - (False, TieBreakPolicy.RANDOM, 4), - ], - ) - def test_make_single_label_records(self, weak_labels, include_abstentions, tie_break_policy, expected): - mj = MajorityVoter(weak_labels) - probs = mj._compute_single_label_probs(weak_labels.matrix()) - - records = mj._make_single_label_records( - probs, - weak_labels.records(), - include_abstentions, - prediction_agent="mock", - tie_break_policy=tie_break_policy, - ) - - assert records[-1].prediction_agent == "mock" - assert records[-1].prediction == [ - ("neutral", 2.0 / 3), - ("negative", 1.0 / 3), - ("positive", 0.0), - ] - assert len(records) == expected - if include_abstentions and 
tie_break_policy is TieBreakPolicy.ABSTAIN: - assert all(rec.prediction is None for rec in records[:3]) - if tie_break_policy is TieBreakPolicy.RANDOM: - assert records[2].prediction == [ - ("negative", 1.0 / 3 + 0.0001), - ("neutral", 1.0 / 3 - 0.00005), - ("positive", 1.0 / 3 - 0.00005), - ] - - def test_make_single_label_records_with_not_implemented_tbp(self, weak_labels): - mj = MajorityVoter(weak_labels) - probs = mj._compute_single_label_probs(weak_labels.matrix()) - - with pytest.raises( - NotImplementedError, - match="tie break policy 'true-random' is not implemented", - ): - mj._make_single_label_records( - probs, - weak_labels.records(), - True, - prediction_agent="mock", - tie_break_policy=TieBreakPolicy.TRUE_RANDOM, - ) - - def test_compute_multi_label_probs(self, weak_multi_labels): - mj = MajorityVoter(weak_multi_labels) - probabilities = mj._compute_multi_label_probs(weak_multi_labels.matrix()) - - expected = np.array( - [[0, 0, 1], [1, 1, 1], [np.nan, np.nan, np.nan], [0, 0, 0]], - dtype=np.float16, - ) - assert np.allclose(probabilities, expected, equal_nan=True) - - @pytest.mark.parametrize( - "include_abstentions,expected", - [ - (True, 4), - (False, 3), - ], - ) - def test_make_multi_label_records(self, weak_multi_labels, include_abstentions, expected): - mj = MajorityVoter(weak_multi_labels) - probs = mj._compute_multi_label_probs(weak_multi_labels.matrix()) - - records = mj._make_multi_label_records( - probs, - weak_multi_labels.records(), - include_abstentions, - prediction_agent="mock", - ) - - assert records[0].prediction_agent == "mock" - assert records[0].prediction == [ - ("scared", 1.0), - ("happy", 0.0), - ("sad", 0.0), - ] - assert len(records) == expected - if include_abstentions: - assert records[2].prediction is None - - def test_score_sklearn_not_installed(self, monkeypatch, weak_labels): - monkeypatch.setattr(sys, "meta_path", [], raising=False) - - mj = MajorityVoter(weak_labels) - with pytest.raises(ModuleNotFoundError, 
match="pip install scikit-learn"): - mj.score() - - @pytest.mark.parametrize( - "wls, output_str", - [ - ("weak_labels", True), - ("weak_multi_labels", False), - ], - ) - def test_score(self, monkeypatch, request, wls, output_str): - def compute_probs(self, wl_matrix, **kwargs): - compute_probs.called = None - - def score(self, probabilities, tie_break_policy=None): - assert probabilities is None - if wls == "weak_labels": - assert tie_break_policy == TieBreakPolicy.ABSTAIN - return np.array([1, 0]), np.array([1, 1]) - - assert tie_break_policy is None - return np.array([[1, 1, 1], [0, 0, 0]]), np.array([[1, 1, 1], [1, 0, 0]]) - - single_or_multi = "multi" if wls == "weak_multi_labels" else "single" - monkeypatch.setattr(MajorityVoter, f"_compute_{single_or_multi}_label_probs", compute_probs) - monkeypatch.setattr(MajorityVoter, f"_score_{single_or_multi}_label", score) - - weak_labels = request.getfixturevalue(wls) - score = MajorityVoter(weak_labels).score(output_str=output_str) - if output_str: - assert isinstance(score, str) - else: - assert isinstance(score, dict) - assert "sad" in score and "happy" in score - assert hasattr(compute_probs, "called") - - @pytest.mark.parametrize( - "tie_break_policy, expected", - [ - (TieBreakPolicy.ABSTAIN, (np.array([2]), np.array([2]))), - (TieBreakPolicy.RANDOM, (np.array([0, 1, 2]), np.array([0, 2, 2]))), - (TieBreakPolicy.TRUE_RANDOM, None), - ], - ) - def test_score_single_label(self, weak_labels, tie_break_policy, expected): - mj = MajorityVoter(weak_labels) - - probabilities = np.array([[0.5, 0.5, 0.0], [0.5, 0.0, 0.5], [1.0 / 3, 0.0, 2.0 / 3]]) - - if tie_break_policy is TieBreakPolicy.TRUE_RANDOM: - with pytest.raises(NotImplementedError, match="not implemented for MajorityVoter"): - mj._score_single_label(probabilities, tie_break_policy) - return - - annotation, prediction = mj._score_single_label(probabilities=probabilities, tie_break_policy=tie_break_policy) - assert np.allclose(annotation, expected[0]) - assert 
np.allclose(prediction, expected[1]) - - def test_score_single_label_no_ties(self, weak_labels): - mj = MajorityVoter(weak_labels) - - probabilities = np.array([[0.5, 0.3, 0.0], [0.5, 0.0, 0.0], [1.0 / 3, 0.0, 2.0 / 3]]) - - _, prediction = mj._score_single_label(probabilities=probabilities, tie_break_policy=TieBreakPolicy.ABSTAIN) - _, prediction2 = mj._score_single_label(probabilities=probabilities, tie_break_policy=TieBreakPolicy.RANDOM) - - assert np.allclose(prediction, prediction2) - - def test_score_multi_label(self, weak_multi_labels): - mj = MajorityVoter(weak_multi_labels) - - probabilities = np.array([[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [np.nan, np.nan, np.nan]]) - - annotation, prediction = mj._score_multi_label(probabilities=probabilities) - - assert np.allclose(annotation, np.array([[0, 0, 1], [1, 0, 1]])) - assert np.allclose(prediction, np.array([[0, 0, 1], [1, 1, 1]])) - - -class TestSnorkel: - def test_not_installed(self, monkeypatch): - monkeypatch.setattr(sys, "meta_path", [], raising=False) - with pytest.raises(ModuleNotFoundError, match="pip install snorkel"): - Snorkel(None) - - def test_init(self, weak_labels): - from snorkel.labeling.model import LabelModel as SnorkelLabelModel - - label_model = Snorkel(weak_labels) - - assert label_model.weak_labels is weak_labels - assert isinstance(label_model._model, SnorkelLabelModel) - assert label_model._model.cardinality == 3 - - @pytest.mark.parametrize( - "wrong_mapping,expected", - [ - ( - {None: -10, "negative": 0, "positive": 1, "neutral": 2}, - {-10: -1, 0: 0, 1: 1, 2: 2}, - ), - ( - {None: -1, "negative": 1, "positive": 3, "neutral": 4}, - {-1: -1, 1: 0, 3: 1, 4: 2}, - ), - ], - ) - def test_init_wrong_mapping(self, weak_labels, wrong_mapping, expected): - weak_labels.change_mapping(wrong_mapping) - label_model = Snorkel(weak_labels) - - assert label_model._weaklabelsInt2snorkelInt == expected - assert label_model._snorkelInt2weaklabelsInt == {k: v for v, k in expected.items()} - - 
@pytest.mark.parametrize( - "include_annotated_records", - [True, False], - ) - def test_fit(self, monkeypatch, weak_labels, include_annotated_records): - def mock_fit(self, L_train, *args, **kwargs): - if include_annotated_records: - assert (L_train == weak_labels.matrix()).all() - else: - assert (L_train == weak_labels.matrix(has_annotation=False)).all() - assert kwargs == {"passed_on": None} - - monkeypatch.setattr( - "snorkel.labeling.model.LabelModel.fit", - mock_fit, - ) - - label_model = Snorkel(weak_labels) - label_model.fit(include_annotated_records=include_annotated_records, passed_on=None) - - def test_fit_automatically_added_kwargs(self, weak_labels): - label_model = Snorkel(weak_labels) - with pytest.raises(ValueError, match="provided automatically"): - label_model.fit(L_train=None) - - @pytest.mark.parametrize( - "policy,include_annotated_records,include_abstentions,expected", - [ - ("abstain", True, False, (2, ["positive", "negative"], [0.8, 0.9])), - ( - "abstain", - True, - True, - (4, [None, None, "positive", "negative"], [None, None, 0.8, 0.9]), - ), - ("random", False, True, (1, ["positive"], [0.8])), - ( - "random", - True, - True, - ( - 4, - ["positive", "negative", "positive", "negative"], - [0.4 + 0.0001, 1.0 / 3 + 0.0001, 0.8, 0.9], - ), - ), - ], - ) - def test_predict( - self, - weak_labels, - monkeypatch, - policy, - include_annotated_records, - include_abstentions, - expected, - ): - def mock_predict(self, L, return_probs, tie_break_policy, *args, **kwargs): - assert tie_break_policy == policy - assert return_probs is True - if include_annotated_records: - assert len(L) == 4 - preds = np.array([-1, -1, 1, 0]) - if policy == "random": - preds = np.array([1, 0, 1, 0]) - return preds, np.array( - [ - [0.4, 0.4, 0.2], - [1.0 / 3, 1.0 / 3, 1.0 / 3], - [0.1, 0.8, 0.1], - [0.9, 0.05, 0.05], - ] - ) - else: - assert len(L) == 1 - return np.array([1]), np.array([[0.1, 0.8, 0.1]]) - - monkeypatch.setattr( - 
"snorkel.labeling.model.LabelModel.predict", - mock_predict, - ) - - label_model = Snorkel(weak_labels) - - records = label_model.predict( - tie_break_policy=policy, - include_annotated_records=include_annotated_records, - include_abstentions=include_abstentions, - prediction_agent="mock_agent", - ) - assert len(records) == expected[0] - assert [rec.prediction[0][0] if rec.prediction else None for rec in records] == expected[1] - assert [rec.prediction[0][1] if rec.prediction else None for rec in records] == expected[2] - assert records[0].prediction_agent == "mock_agent" - - @pytest.mark.parametrize("policy,expected", [("abstain", 0.5), ("random", 2.0 / 3)]) - def test_score(self, monkeypatch, weak_labels, policy, expected): - def mock_predict(self, L, return_probs, tie_break_policy): - assert (L == weak_labels.matrix(has_annotation=True)).all() - assert return_probs is True - assert tie_break_policy == policy - if policy == "abstain": - predictions = np.array([-1, 1, 0]) - elif policy == "random": - predictions = np.array([0, 1, 0]) - else: - raise ValueError("Untested policy!") - - probabilities = None # accuracy does not need probabs ... 
- - return predictions, probabilities - - monkeypatch.setattr( - "snorkel.labeling.model.LabelModel.predict", - mock_predict, - ) - - label_model = Snorkel(weak_labels) - metrics = label_model.score(tie_break_policy=policy) - - assert metrics["accuracy"] == pytest.approx(expected) - assert list(metrics.keys())[:3] == ["negative", "positive", "neutral"] - - def test_score_without_annotations(self, weak_labels): - weak_labels._annotation = np.array([], dtype=np.short) - label_model = Snorkel(weak_labels) - - with pytest.raises(MissingAnnotationError, match="need annotated records"): - label_model.score() - - @pytest.mark.parametrize( - "change_mapping", - [False, True], - ) - def test_integration(self, weak_labels_from_guide, change_mapping): - if change_mapping: - weak_labels_from_guide.change_mapping({None: -10, "HAM": 2, "SPAM": 5}) - label_model = Snorkel(weak_labels_from_guide) - label_model.fit(seed=43) - - metrics = label_model.score() - assert metrics["accuracy"] == pytest.approx(0.8947368421052632) - - records = label_model.predict() - assert len(records) == 1177 - assert records[0].prediction == [ - ("SPAM", pytest.approx(0.5633776670811805)), - ("HAM", pytest.approx(0.4366223329188196)), - ] - - -class TestFlyingSquid: - def test_not_installed(self, monkeypatch): - monkeypatch.setattr(sys, "meta_path", [], raising=False) - with pytest.raises(ModuleNotFoundError, match="pip install flyingsquid"): - FlyingSquid(None) - - def test_init(self, weak_labels): - FlyingSquid(weak_labels) - - with pytest.raises(ValueError, match="must not contain 'm'"): - FlyingSquid(weak_labels, m="mock") - - weak_labels._rules = weak_labels.rules[:2] - with pytest.raises(TooFewRulesError, match="at least three"): - FlyingSquid(weak_labels) - - @pytest.mark.parametrize("include_annotated,expected", [(False, 1), (True, 4)]) - def test_fit(self, monkeypatch, weak_labels, include_annotated, expected): - def mock_fit(*args, **kwargs): - if not include_annotated: - assert 
(kwargs["L_train"] == np.array([0, 0, 0])).all() - assert len(kwargs["L_train"]) == expected - - monkeypatch.setattr( - "flyingsquid.label_model.LabelModel.fit", - mock_fit, - ) - - label_model = FlyingSquid(weak_labels) - label_model.fit(include_annotated_records=include_annotated) - - assert len(label_model._models) == 3 - - def test_fit_init_kwargs(self, monkeypatch, weak_labels): - class MockLabelModel: - def __init__(self, m, mock): - assert m == len(weak_labels.rules) - assert mock == "mock" - - def fit(self, L_train, mock): - assert mock == "mock_fit_kwargs" - - monkeypatch.setattr( - "flyingsquid.label_model.LabelModel", - MockLabelModel, - ) - - label_model = FlyingSquid(weak_labels, mock="mock") - label_model.fit(mock="mock_fit_kwargs") - - @pytest.mark.parametrize( - "policy,include_annotated_records,include_abstentions,verbose,expected", - [ - ( - "abstain", - False, - False, - True, - { - "verbose": True, - "L_matrix_length": 1, - "return": np.array([[0.5, 0.5]]), - "nr_of_records": 0, - }, - ), - ( - "abstain", - True, - True, - False, - { - "verbose": False, - "L_matrix_length": 4, - "return": np.array([[0.5, 0.5] * 4]), - "nr_of_records": 4, - "prediction": None, - }, - ), - ( - "random", - False, - False, - False, - { - "verbose": False, - "L_matrix_length": 1, - "return": np.array([[0.5, 0.5]]), - "nr_of_records": 1, - "prediction": [ - ("negative", 0.3334333333333333), - ("neutral", 0.3332833333333333), - ("positive", 0.3332833333333333), - ], - }, - ), - ], - ) - def test_predict( - self, - weak_labels, - monkeypatch, - policy, - include_annotated_records, - include_abstentions, - verbose, - expected, - ): - class MockPredict: - calls_count = 0 - - @classmethod - def __call__(cls, L_matrix, verbose): - assert verbose is expected["verbose"] - assert len(L_matrix) == expected["L_matrix_length"] - cls.calls_count += 1 - - return expected["return"] - - monkeypatch.setattr( - "flyingsquid.label_model.LabelModel.predict_proba", - MockPredict(), - ) - 
- label_model = FlyingSquid(weak_labels) - label_model.fit() - - records = label_model.predict( - tie_break_policy=policy, - include_annotated_records=include_annotated_records, - include_abstentions=include_abstentions, - verbose=verbose, - prediction_agent="mock_agent", - ) - - assert MockPredict.calls_count == 3 - assert len(records) == expected["nr_of_records"] - if records: - assert records[0].prediction == expected["prediction"] - assert records[0].prediction_agent == "mock_agent" - - def test_predict_binary(self, monkeypatch, weak_labels): - class MockPredict: - calls_count = 0 - - @classmethod - def __call__(cls, L_matrix, verbose): - cls.calls_count += 1 - return np.array([[0.6, 0.4]]) - - monkeypatch.setattr( - "flyingsquid.label_model.LabelModel.predict_proba", - MockPredict(), - ) - - weak_labels._label2int = {None: -1, "negative": 0, "positive": 1} - - label_model = FlyingSquid(weak_labels=weak_labels) - label_model.fit() - - records = label_model.predict() - - assert MockPredict.calls_count == 1 - assert len(records) == 1 - assert records[0].prediction == [("negative", 0.6), ("positive", 0.4)] - - def test_predict_not_implented_tbp(self, weak_labels): - label_model = FlyingSquid(weak_labels) - label_model.fit() - - with pytest.raises(NotImplementedError, match="true-random"): - label_model.predict(tie_break_policy="true-random") - - def test_predict_not_fitted_error(self, weak_labels): - label_model = FlyingSquid(weak_labels) - with pytest.raises(NotFittedError, match="not fitted yet"): - label_model.predict() - - def test_score_not_fitted_error(self, weak_labels): - label_model = FlyingSquid(weak_labels) - with pytest.raises(NotFittedError, match="not fitted yet"): - label_model.score() - - def test_score_sklearn_not_installed(self, monkeypatch: pytest.MonkeyPatch, weak_labels): - label_model = FlyingSquid(weak_labels) - - monkeypatch.setattr(sys, "meta_path", [], raising=False) - with pytest.raises(ModuleNotFoundError, match="pip install 
scikit-learn"): - label_model.score() - - def test_score(self, monkeypatch, weak_labels): - def mock_predict(weak_label_matrix, verbose): - assert verbose is False - assert len(weak_label_matrix) == 3 - return np.array([[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]]) - - label_model = FlyingSquid(weak_labels) - # We have to monkeypatch the instance rather than the class due to decorators - # on the class - monkeypatch.setattr(label_model, "_predict", mock_predict) - metrics = label_model.score() - - assert "accuracy" in metrics - assert metrics["accuracy"] == pytest.approx(1.0) - assert list(metrics.keys())[:3] == ["negative", "positive", "neutral"] - - assert isinstance(label_model.score(output_str=True), str) - - @pytest.mark.parametrize("tbp,vrb,expected", [("abstain", False, 1.0), ("random", True, 2 / 3.0)]) - def test_score_tbp(self, monkeypatch, weak_labels, tbp, vrb, expected): - def mock_predict(weak_label_matrix, verbose): - assert verbose is vrb - assert len(weak_label_matrix) == 3 - return np.array([[0.8, 0.1, 0.1], [0.4, 0.4, 0.2], [1 / 3.0, 1 / 3.0, 1 / 3.0]]) - - label_model = FlyingSquid(weak_labels) - - monkeypatch.setattr(label_model, "_predict", mock_predict) - - metrics = label_model.score(tie_break_policy=tbp, verbose=vrb) - - assert metrics["accuracy"] == pytest.approx(expected) - if tbp == "abstain": - assert list(metrics.keys())[:1] == ["negative"] - elif tbp == "random": - assert list(metrics.keys())[:3] == ["negative", "positive", "neutral"] - - def test_score_not_implemented_tbp(self, weak_labels): - label_model = FlyingSquid(weak_labels) - label_model.fit() - - with pytest.raises(NotImplementedError, match="true-random"): - label_model.score(tie_break_policy="true-random") - - def test_integration(self, weak_labels_from_guide): - label_model = FlyingSquid(weak_labels_from_guide) - label_model.fit() - - metrics = label_model.score() - assert metrics["accuracy"] == pytest.approx(0.9282296650717703) - - records = label_model.predict() - 
assert len(records) == 1177 - - prediction = records[0].prediction - spam_prediction_probability = prediction[0][1] - ham_prediction_probability = prediction[1][1] - assert spam_prediction_probability == pytest.approx(0.8236983486087645) - assert ham_prediction_probability == pytest.approx(0.17630165139123552) diff --git a/codecov.yml b/codecov.yml index 2ec5f38bd..b6d20a2c9 100644 --- a/codecov.yml +++ b/codecov.yml @@ -12,7 +12,7 @@ coverage: target: auto threshold: 2% informational: true - + flags: frontend: paths: @@ -20,11 +20,11 @@ flags: carryforward: true argilla: paths: - - argilla/src/argilla/ + - extralit/src/argilla/ carryforward: true extralit: paths: - - argilla/src/extralit/ + - extralit/src/extralit/ carryforward: true argilla_v1: paths: @@ -40,16 +40,16 @@ component_management: statuses: - type: project target: auto - + individual_components: - component_id: extralit paths: - - argilla/src/extralit/** + - extralit/src/extralit/** - component_id: argilla name: argilla paths: - - argilla/src/argilla/** + - extralit/src/argilla/** - component_id: argilla_v1 name: argilla_v1 @@ -59,7 +59,7 @@ component_management: - component_id: argilla-server paths: - argilla-server/** - + - component_id: argilla-frontend paths: - argilla-frontend/** diff --git a/argilla/.env.test b/extralit/.env.test similarity index 100% rename from argilla/.env.test rename to extralit/.env.test diff --git a/argilla/.gitignore b/extralit/.gitignore similarity index 100% rename from argilla/.gitignore rename to extralit/.gitignore diff --git a/argilla/CHANGELOG.md b/extralit/CHANGELOG.md similarity index 99% rename from argilla/CHANGELOG.md rename to extralit/CHANGELOG.md index a74bcc2f8..e0e3d6127 100644 --- a/argilla/CHANGELOG.md +++ b/extralit/CHANGELOG.md @@ -19,6 +19,11 @@ These are the section headers that we use: ### Changed - Updated `upload_file` function to streamline file upload process and improve user feedback. 
- Modified document listing and file upload functionalities for better user experience and feedback. +- Refactored `argilla/*` to `extralit/*` to align with the new project structure. +- Changed `EXTRALIT_CACHE_DIR` to `~/.extralit/` from `~/.argilla/` to align with new project structure. + +### Deprecated +- `argilla-v1` is deprecated and will be removed in the next major release. Use `extralit` instead. ### Fixed - Fixed all integration tests. @@ -26,7 +31,7 @@ These are the section headers that we use: - Enhanced test failure handling and updated test commands to suppress warnings. - Handle dataset not found errors in Hugging Face dataset tests. - Update spaCy and pyarrow dependencies for Python version compatibility. -- Update `argilla.yml` on Python 3.13. +- Update `extralit.yml` on Python 3.13. - Update spaCy and spaCy-wheel version constraints for compatibility. ## [Argilla] [2.8.0](https://github.com/argilla-io/argilla/compare/v2.6.0...v2.8.0) diff --git a/argilla/LICENSE b/extralit/LICENSE similarity index 100% rename from argilla/LICENSE rename to extralit/LICENSE diff --git a/argilla/LICENSE_HEADER b/extralit/LICENSE_HEADER similarity index 100% rename from argilla/LICENSE_HEADER rename to extralit/LICENSE_HEADER diff --git a/argilla/README.md b/extralit/README.md similarity index 100% rename from argilla/README.md rename to extralit/README.md diff --git a/argilla/docker/extralit.dockerfile b/extralit/docker/extralit.dockerfile similarity index 100% rename from argilla/docker/extralit.dockerfile rename to extralit/docker/extralit.dockerfile diff --git a/argilla/docs/admin_guide/annotate.md b/extralit/docs/admin_guide/annotate.md similarity index 100% rename from argilla/docs/admin_guide/annotate.md rename to extralit/docs/admin_guide/annotate.md diff --git a/argilla/docs/admin_guide/custom_fields.md b/extralit/docs/admin_guide/custom_fields.md similarity index 100% rename from argilla/docs/admin_guide/custom_fields.md rename to 
extralit/docs/admin_guide/custom_fields.md diff --git a/argilla/docs/admin_guide/dataset.md b/extralit/docs/admin_guide/dataset.md similarity index 100% rename from argilla/docs/admin_guide/dataset.md rename to extralit/docs/admin_guide/dataset.md diff --git a/argilla/docs/admin_guide/distribution.md b/extralit/docs/admin_guide/distribution.md similarity index 100% rename from argilla/docs/admin_guide/distribution.md rename to extralit/docs/admin_guide/distribution.md diff --git a/argilla/docs/admin_guide/docker_deployment.md b/extralit/docs/admin_guide/docker_deployment.md similarity index 100% rename from argilla/docs/admin_guide/docker_deployment.md rename to extralit/docs/admin_guide/docker_deployment.md diff --git a/argilla/docs/admin_guide/import_export.md b/extralit/docs/admin_guide/import_export.md similarity index 100% rename from argilla/docs/admin_guide/import_export.md rename to extralit/docs/admin_guide/import_export.md diff --git a/argilla/docs/admin_guide/index.md b/extralit/docs/admin_guide/index.md similarity index 100% rename from argilla/docs/admin_guide/index.md rename to extralit/docs/admin_guide/index.md diff --git a/argilla/docs/admin_guide/k8s_deployment.md b/extralit/docs/admin_guide/k8s_deployment.md similarity index 100% rename from argilla/docs/admin_guide/k8s_deployment.md rename to extralit/docs/admin_guide/k8s_deployment.md diff --git a/argilla/docs/admin_guide/migrate_from_legacy_datasets.md b/extralit/docs/admin_guide/migrate_from_legacy_datasets.md similarity index 100% rename from argilla/docs/admin_guide/migrate_from_legacy_datasets.md rename to extralit/docs/admin_guide/migrate_from_legacy_datasets.md diff --git a/argilla/docs/admin_guide/query.md b/extralit/docs/admin_guide/query.md similarity index 100% rename from argilla/docs/admin_guide/query.md rename to extralit/docs/admin_guide/query.md diff --git a/argilla/docs/admin_guide/record.md b/extralit/docs/admin_guide/record.md similarity index 100% rename from 
argilla/docs/admin_guide/record.md rename to extralit/docs/admin_guide/record.md diff --git a/argilla/docs/admin_guide/upgrading.md b/extralit/docs/admin_guide/upgrading.md similarity index 100% rename from argilla/docs/admin_guide/upgrading.md rename to extralit/docs/admin_guide/upgrading.md diff --git a/argilla/docs/admin_guide/use_markdown_to_format_rich_content.md b/extralit/docs/admin_guide/use_markdown_to_format_rich_content.md similarity index 100% rename from argilla/docs/admin_guide/use_markdown_to_format_rich_content.md rename to extralit/docs/admin_guide/use_markdown_to_format_rich_content.md diff --git a/argilla/docs/admin_guide/user.md b/extralit/docs/admin_guide/user.md similarity index 100% rename from argilla/docs/admin_guide/user.md rename to extralit/docs/admin_guide/user.md diff --git a/argilla/docs/admin_guide/webhooks.md b/extralit/docs/admin_guide/webhooks.md similarity index 100% rename from argilla/docs/admin_guide/webhooks.md rename to extralit/docs/admin_guide/webhooks.md diff --git a/argilla/docs/admin_guide/webhooks_internals.md b/extralit/docs/admin_guide/webhooks_internals.md similarity index 100% rename from argilla/docs/admin_guide/webhooks_internals.md rename to extralit/docs/admin_guide/webhooks_internals.md diff --git a/argilla/docs/admin_guide/workspace.md b/extralit/docs/admin_guide/workspace.md similarity index 100% rename from argilla/docs/admin_guide/workspace.md rename to extralit/docs/admin_guide/workspace.md diff --git a/argilla/docs/assets/favicon.ico b/extralit/docs/assets/favicon.ico similarity index 100% rename from argilla/docs/assets/favicon.ico rename to extralit/docs/assets/favicon.ico diff --git a/argilla/docs/assets/images/community/contributing/argilla-slack.png b/extralit/docs/assets/images/community/contributing/argilla-slack.png similarity index 100% rename from argilla/docs/assets/images/community/contributing/argilla-slack.png rename to extralit/docs/assets/images/community/contributing/argilla-slack.png 
diff --git a/argilla/docs/assets/images/community/contributing/channels.PNG b/extralit/docs/assets/images/community/contributing/channels.PNG similarity index 100% rename from argilla/docs/assets/images/community/contributing/channels.PNG rename to extralit/docs/assets/images/community/contributing/channels.PNG diff --git a/argilla/docs/assets/images/community/contributing/compare-across-forks.PNG b/extralit/docs/assets/images/community/contributing/compare-across-forks.PNG similarity index 100% rename from argilla/docs/assets/images/community/contributing/compare-across-forks.PNG rename to extralit/docs/assets/images/community/contributing/compare-across-forks.PNG diff --git a/argilla/docs/assets/images/community/contributing/compare-pull-request.PNG b/extralit/docs/assets/images/community/contributing/compare-pull-request.PNG similarity index 100% rename from argilla/docs/assets/images/community/contributing/compare-pull-request.PNG rename to extralit/docs/assets/images/community/contributing/compare-pull-request.PNG diff --git a/argilla/docs/assets/images/community/contributing/create-branch-together.png b/extralit/docs/assets/images/community/contributing/create-branch-together.png similarity index 100% rename from argilla/docs/assets/images/community/contributing/create-branch-together.png rename to extralit/docs/assets/images/community/contributing/create-branch-together.png diff --git a/argilla/docs/assets/images/community/contributing/create-branch.PNG b/extralit/docs/assets/images/community/contributing/create-branch.PNG similarity index 100% rename from argilla/docs/assets/images/community/contributing/create-branch.PNG rename to extralit/docs/assets/images/community/contributing/create-branch.PNG diff --git a/argilla/docs/assets/images/community/contributing/create-fork.PNG b/extralit/docs/assets/images/community/contributing/create-fork.PNG similarity index 100% rename from argilla/docs/assets/images/community/contributing/create-fork.PNG rename to 
extralit/docs/assets/images/community/contributing/create-fork.PNG diff --git a/argilla/docs/assets/images/community/contributing/edit-file.PNG b/extralit/docs/assets/images/community/contributing/edit-file.PNG similarity index 100% rename from argilla/docs/assets/images/community/contributing/edit-file.PNG rename to extralit/docs/assets/images/community/contributing/edit-file.PNG diff --git a/argilla/docs/assets/images/community/contributing/fork-bar.PNG b/extralit/docs/assets/images/community/contributing/fork-bar.PNG similarity index 100% rename from argilla/docs/assets/images/community/contributing/fork-bar.PNG rename to extralit/docs/assets/images/community/contributing/fork-bar.PNG diff --git a/argilla/docs/assets/images/community/contributing/issue-feature-template.PNG b/extralit/docs/assets/images/community/contributing/issue-feature-template.PNG similarity index 100% rename from argilla/docs/assets/images/community/contributing/issue-feature-template.PNG rename to extralit/docs/assets/images/community/contributing/issue-feature-template.PNG diff --git a/argilla/docs/assets/images/community/contributing/issues-page.PNG b/extralit/docs/assets/images/community/contributing/issues-page.PNG similarity index 100% rename from argilla/docs/assets/images/community/contributing/issues-page.PNG rename to extralit/docs/assets/images/community/contributing/issues-page.PNG diff --git a/argilla/docs/assets/images/community/contributing/pull-request-feature.PNG b/extralit/docs/assets/images/community/contributing/pull-request-feature.PNG similarity index 100% rename from argilla/docs/assets/images/community/contributing/pull-request-feature.PNG rename to extralit/docs/assets/images/community/contributing/pull-request-feature.PNG diff --git a/argilla/docs/assets/images/community/contributing/pull-request.PNG b/extralit/docs/assets/images/community/contributing/pull-request.PNG similarity index 100% rename from 
argilla/docs/assets/images/community/contributing/pull-request.PNG rename to extralit/docs/assets/images/community/contributing/pull-request.PNG diff --git a/argilla/docs/assets/images/community/contributing/templates-issues.PNG b/extralit/docs/assets/images/community/contributing/templates-issues.PNG similarity index 100% rename from argilla/docs/assets/images/community/contributing/templates-issues.PNG rename to extralit/docs/assets/images/community/contributing/templates-issues.PNG diff --git a/argilla/docs/assets/images/community/developer/database_tables.png b/extralit/docs/assets/images/community/developer/database_tables.png similarity index 100% rename from argilla/docs/assets/images/community/developer/database_tables.png rename to extralit/docs/assets/images/community/developer/database_tables.png diff --git a/argilla/docs/assets/images/community/developer/repo-visualizer-argilla-server.svg b/extralit/docs/assets/images/community/developer/repo-visualizer-argilla-server.svg similarity index 100% rename from argilla/docs/assets/images/community/developer/repo-visualizer-argilla-server.svg rename to extralit/docs/assets/images/community/developer/repo-visualizer-argilla-server.svg diff --git a/argilla/docs/assets/images/community/developer/repo-visualizer-argilla.svg b/extralit/docs/assets/images/community/developer/repo-visualizer-argilla.svg similarity index 100% rename from argilla/docs/assets/images/community/developer/repo-visualizer-argilla.svg rename to extralit/docs/assets/images/community/developer/repo-visualizer-argilla.svg diff --git a/argilla/docs/assets/images/community/integrations/llamaindex_rag_1.png b/extralit/docs/assets/images/community/integrations/llamaindex_rag_1.png similarity index 100% rename from argilla/docs/assets/images/community/integrations/llamaindex_rag_1.png rename to extralit/docs/assets/images/community/integrations/llamaindex_rag_1.png diff --git a/argilla/docs/assets/images/docker-mark-blue.svg 
b/extralit/docs/assets/images/docker-mark-blue.svg similarity index 100% rename from argilla/docs/assets/images/docker-mark-blue.svg rename to extralit/docs/assets/images/docker-mark-blue.svg diff --git a/argilla/docs/assets/images/getting_started/data-extraction-pipeline.jpg b/extralit/docs/assets/images/getting_started/data-extraction-pipeline.jpg similarity index 100% rename from argilla/docs/assets/images/getting_started/data-extraction-pipeline.jpg rename to extralit/docs/assets/images/getting_started/data-extraction-pipeline.jpg diff --git a/argilla/docs/assets/images/getting_started/dataset_configurator.png b/extralit/docs/assets/images/getting_started/dataset_configurator.png similarity index 100% rename from argilla/docs/assets/images/getting_started/dataset_configurator.png rename to extralit/docs/assets/images/getting_started/dataset_configurator.png diff --git a/argilla/docs/assets/images/getting_started/signin-hf-page.png b/extralit/docs/assets/images/getting_started/signin-hf-page.png similarity index 100% rename from argilla/docs/assets/images/getting_started/signin-hf-page.png rename to extralit/docs/assets/images/getting_started/signin-hf-page.png diff --git a/argilla/docs/assets/images/how_to_guides/annotate/bulk_view.png b/extralit/docs/assets/images/how_to_guides/annotate/bulk_view.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/annotate/bulk_view.png rename to extralit/docs/assets/images/how_to_guides/annotate/bulk_view.png diff --git a/argilla/docs/assets/images/how_to_guides/annotate/focus_view.png b/extralit/docs/assets/images/how_to_guides/annotate/focus_view.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/annotate/focus_view.png rename to extralit/docs/assets/images/how_to_guides/annotate/focus_view.png diff --git a/argilla/docs/assets/images/how_to_guides/annotate/ui_overview.png b/extralit/docs/assets/images/how_to_guides/annotate/ui_overview.png similarity index 100% rename 
from argilla/docs/assets/images/how_to_guides/annotate/ui_overview.png rename to extralit/docs/assets/images/how_to_guides/annotate/ui_overview.png diff --git a/argilla/docs/assets/images/how_to_guides/custom_field/3d_object_viewer.png b/extralit/docs/assets/images/how_to_guides/custom_field/3d_object_viewer.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/custom_field/3d_object_viewer.png rename to extralit/docs/assets/images/how_to_guides/custom_field/3d_object_viewer.png diff --git a/argilla/docs/assets/images/how_to_guides/custom_field/images_in_two_columns.png b/extralit/docs/assets/images/how_to_guides/custom_field/images_in_two_columns.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/custom_field/images_in_two_columns.png rename to extralit/docs/assets/images/how_to_guides/custom_field/images_in_two_columns.png diff --git a/argilla/docs/assets/images/how_to_guides/custom_field/metadata_table.png b/extralit/docs/assets/images/how_to_guides/custom_field/metadata_table.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/custom_field/metadata_table.png rename to extralit/docs/assets/images/how_to_guides/custom_field/metadata_table.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/chat_field.png b/extralit/docs/assets/images/how_to_guides/dataset/chat_field.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/chat_field.png rename to extralit/docs/assets/images/how_to_guides/dataset/chat_field.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/fields.png b/extralit/docs/assets/images/how_to_guides/dataset/fields.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/fields.png rename to extralit/docs/assets/images/how_to_guides/dataset/fields.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/float_metadata.png 
b/extralit/docs/assets/images/how_to_guides/dataset/float_metadata.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/float_metadata.png rename to extralit/docs/assets/images/how_to_guides/dataset/float_metadata.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/guidelines.png b/extralit/docs/assets/images/how_to_guides/dataset/guidelines.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/guidelines.png rename to extralit/docs/assets/images/how_to_guides/dataset/guidelines.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/guidelines_description.png b/extralit/docs/assets/images/how_to_guides/dataset/guidelines_description.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/guidelines_description.png rename to extralit/docs/assets/images/how_to_guides/dataset/guidelines_description.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/image_field.png b/extralit/docs/assets/images/how_to_guides/dataset/image_field.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/image_field.png rename to extralit/docs/assets/images/how_to_guides/dataset/image_field.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/integer_metadata.png b/extralit/docs/assets/images/how_to_guides/dataset/integer_metadata.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/integer_metadata.png rename to extralit/docs/assets/images/how_to_guides/dataset/integer_metadata.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/label_question.png b/extralit/docs/assets/images/how_to_guides/dataset/label_question.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/label_question.png rename to extralit/docs/assets/images/how_to_guides/dataset/label_question.png diff --git 
a/argilla/docs/assets/images/how_to_guides/dataset/multilabel_question.png b/extralit/docs/assets/images/how_to_guides/dataset/multilabel_question.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/multilabel_question.png rename to extralit/docs/assets/images/how_to_guides/dataset/multilabel_question.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/ranking_question.png b/extralit/docs/assets/images/how_to_guides/dataset/ranking_question.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/ranking_question.png rename to extralit/docs/assets/images/how_to_guides/dataset/ranking_question.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/rating_question.png b/extralit/docs/assets/images/how_to_guides/dataset/rating_question.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/rating_question.png rename to extralit/docs/assets/images/how_to_guides/dataset/rating_question.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/span_question.png b/extralit/docs/assets/images/how_to_guides/dataset/span_question.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/span_question.png rename to extralit/docs/assets/images/how_to_guides/dataset/span_question.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/term_metadata.png b/extralit/docs/assets/images/how_to_guides/dataset/term_metadata.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/term_metadata.png rename to extralit/docs/assets/images/how_to_guides/dataset/term_metadata.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/text_field.png b/extralit/docs/assets/images/how_to_guides/dataset/text_field.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/text_field.png rename to extralit/docs/assets/images/how_to_guides/dataset/text_field.png diff 
--git a/argilla/docs/assets/images/how_to_guides/dataset/text_question.png b/extralit/docs/assets/images/how_to_guides/dataset/text_question.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/text_question.png rename to extralit/docs/assets/images/how_to_guides/dataset/text_question.png diff --git a/argilla/docs/assets/images/how_to_guides/dataset/vectors.png b/extralit/docs/assets/images/how_to_guides/dataset/vectors.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/dataset/vectors.png rename to extralit/docs/assets/images/how_to_guides/dataset/vectors.png diff --git a/argilla/docs/assets/images/how_to_guides/distribution/taskdistribution.svg b/extralit/docs/assets/images/how_to_guides/distribution/taskdistribution.svg similarity index 100% rename from argilla/docs/assets/images/how_to_guides/distribution/taskdistribution.svg rename to extralit/docs/assets/images/how_to_guides/distribution/taskdistribution.svg diff --git a/argilla/docs/assets/images/how_to_guides/markdown/chat.png b/extralit/docs/assets/images/how_to_guides/markdown/chat.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/markdown/chat.png rename to extralit/docs/assets/images/how_to_guides/markdown/chat.png diff --git a/argilla/docs/assets/images/how_to_guides/markdown/media.png b/extralit/docs/assets/images/how_to_guides/markdown/media.png similarity index 100% rename from argilla/docs/assets/images/how_to_guides/markdown/media.png rename to extralit/docs/assets/images/how_to_guides/markdown/media.png diff --git a/argilla/docs/assets/images/huggingface-mark.svg b/extralit/docs/assets/images/huggingface-mark.svg similarity index 100% rename from argilla/docs/assets/images/huggingface-mark.svg rename to extralit/docs/assets/images/huggingface-mark.svg diff --git a/argilla/docs/assets/logo.png b/extralit/docs/assets/logo.png similarity index 100% rename from argilla/docs/assets/logo.png rename to 
extralit/docs/assets/logo.png diff --git a/argilla/docs/assets/logo.svg b/extralit/docs/assets/logo.svg similarity index 100% rename from argilla/docs/assets/logo.svg rename to extralit/docs/assets/logo.svg diff --git a/argilla/docs/assets/og-doc.png b/extralit/docs/assets/og-doc.png similarity index 100% rename from argilla/docs/assets/og-doc.png rename to extralit/docs/assets/og-doc.png diff --git a/argilla/docs/community/adding_language.md b/extralit/docs/community/adding_language.md similarity index 100% rename from argilla/docs/community/adding_language.md rename to extralit/docs/community/adding_language.md diff --git a/argilla/docs/community/contributor.md b/extralit/docs/community/contributor.md similarity index 100% rename from argilla/docs/community/contributor.md rename to extralit/docs/community/contributor.md diff --git a/argilla/docs/community/developer.md b/extralit/docs/community/developer.md similarity index 100% rename from argilla/docs/community/developer.md rename to extralit/docs/community/developer.md diff --git a/argilla/docs/community/index.md b/extralit/docs/community/index.md similarity index 100% rename from argilla/docs/community/index.md rename to extralit/docs/community/index.md diff --git a/argilla/docs/community/integrations/llamaindex_rag_github.ipynb b/extralit/docs/community/integrations/llamaindex_rag_github.ipynb similarity index 100% rename from argilla/docs/community/integrations/llamaindex_rag_github.ipynb rename to extralit/docs/community/integrations/llamaindex_rag_github.ipynb diff --git a/argilla/docs/community/release_guide.md b/extralit/docs/community/release_guide.md similarity index 100% rename from argilla/docs/community/release_guide.md rename to extralit/docs/community/release_guide.md diff --git a/argilla/docs/faq.md b/extralit/docs/faq.md similarity index 100% rename from argilla/docs/faq.md rename to extralit/docs/faq.md diff --git a/argilla/docs/getting_started/development_setup.md 
b/extralit/docs/getting_started/development_setup.md similarity index 98% rename from argilla/docs/getting_started/development_setup.md rename to extralit/docs/getting_started/development_setup.md index c92166071..8afc9bb34 100644 --- a/argilla/docs/getting_started/development_setup.md +++ b/extralit/docs/getting_started/development_setup.md @@ -158,7 +158,7 @@ cd argilla-server pdm install # Install client dependencies -cd ../argilla +cd ../extralit pdm install ``` @@ -180,7 +180,7 @@ Create a `.env.dev` file in the `argilla-server` directory with the following co ``` ALEMBIC_CONFIG=src/argilla_server/alembic.ini ARGILLA_AUTH_SECRET_KEY=8VO7na5N/jQx+yP/N+HlE8q51vPdrxqlh6OzoebIyko= -ARGILLA_DATABASE_URL=sqlite+aiosqlite:///${HOME}/.argilla/argilla-dev.db?check_same_thread=False +ARGILLA_DATABASE_URL=sqlite+aiosqlite:///${HOME}/.extralit/argilla-dev.db?check_same_thread=False # Search engine configuration ARGILLA_SEARCH_ENGINE=elasticsearch ARGILLA_ELASTICSEARCH=http://localhost:9200 @@ -377,7 +377,7 @@ If database migrations fail: ```bash # Reset the database -rm -rf ~/.argilla/argilla-dev.db +rm -rf ~/.extralit/argilla-dev.db pdm run migrate ``` diff --git a/argilla/docs/getting_started/faq.md b/extralit/docs/getting_started/faq.md similarity index 100% rename from argilla/docs/getting_started/faq.md rename to extralit/docs/getting_started/faq.md diff --git a/argilla/docs/getting_started/how-to-configure-argilla-on-huggingface.md b/extralit/docs/getting_started/how-to-configure-argilla-on-huggingface.md similarity index 100% rename from argilla/docs/getting_started/how-to-configure-argilla-on-huggingface.md rename to extralit/docs/getting_started/how-to-configure-argilla-on-huggingface.md diff --git a/argilla/docs/getting_started/how-to-deploy-argilla-with-docker.md b/extralit/docs/getting_started/how-to-deploy-argilla-with-docker.md similarity index 100% rename from argilla/docs/getting_started/how-to-deploy-argilla-with-docker.md rename to 
extralit/docs/getting_started/how-to-deploy-argilla-with-docker.md diff --git a/argilla/docs/getting_started/installation.md b/extralit/docs/getting_started/installation.md similarity index 100% rename from argilla/docs/getting_started/installation.md rename to extralit/docs/getting_started/installation.md diff --git a/argilla/docs/getting_started/quickstart.md b/extralit/docs/getting_started/quickstart.md similarity index 100% rename from argilla/docs/getting_started/quickstart.md rename to extralit/docs/getting_started/quickstart.md diff --git a/argilla/docs/glossary.md b/extralit/docs/glossary.md similarity index 100% rename from argilla/docs/glossary.md rename to extralit/docs/glossary.md diff --git a/argilla/docs/index.md b/extralit/docs/index.md similarity index 100% rename from argilla/docs/index.md rename to extralit/docs/index.md diff --git a/argilla/docs/reference/argilla-server/configuration.md b/extralit/docs/reference/argilla-server/configuration.md similarity index 100% rename from argilla/docs/reference/argilla-server/configuration.md rename to extralit/docs/reference/argilla-server/configuration.md diff --git a/argilla/docs/reference/argilla-server/oauth2_configuration.md b/extralit/docs/reference/argilla-server/oauth2_configuration.md similarity index 100% rename from argilla/docs/reference/argilla-server/oauth2_configuration.md rename to extralit/docs/reference/argilla-server/oauth2_configuration.md diff --git a/argilla/docs/reference/argilla-server/sso_keycloak.md b/extralit/docs/reference/argilla-server/sso_keycloak.md similarity index 100% rename from argilla/docs/reference/argilla-server/sso_keycloak.md rename to extralit/docs/reference/argilla-server/sso_keycloak.md diff --git a/argilla/docs/reference/argilla-server/telemetry.md b/extralit/docs/reference/argilla-server/telemetry.md similarity index 100% rename from argilla/docs/reference/argilla-server/telemetry.md rename to extralit/docs/reference/argilla-server/telemetry.md diff --git 
a/argilla/docs/reference/argilla/SUMMARY.md b/extralit/docs/reference/argilla/SUMMARY.md similarity index 100% rename from argilla/docs/reference/argilla/SUMMARY.md rename to extralit/docs/reference/argilla/SUMMARY.md diff --git a/argilla/docs/reference/argilla/client.md b/extralit/docs/reference/argilla/client.md similarity index 100% rename from argilla/docs/reference/argilla/client.md rename to extralit/docs/reference/argilla/client.md diff --git a/argilla/docs/reference/argilla/datasets/dataset_records.md b/extralit/docs/reference/argilla/datasets/dataset_records.md similarity index 100% rename from argilla/docs/reference/argilla/datasets/dataset_records.md rename to extralit/docs/reference/argilla/datasets/dataset_records.md diff --git a/argilla/docs/reference/argilla/datasets/datasets.md b/extralit/docs/reference/argilla/datasets/datasets.md similarity index 100% rename from argilla/docs/reference/argilla/datasets/datasets.md rename to extralit/docs/reference/argilla/datasets/datasets.md diff --git a/argilla/docs/reference/argilla/markdown.md b/extralit/docs/reference/argilla/markdown.md similarity index 100% rename from argilla/docs/reference/argilla/markdown.md rename to extralit/docs/reference/argilla/markdown.md diff --git a/argilla/docs/reference/argilla/records/metadata.md b/extralit/docs/reference/argilla/records/metadata.md similarity index 100% rename from argilla/docs/reference/argilla/records/metadata.md rename to extralit/docs/reference/argilla/records/metadata.md diff --git a/argilla/docs/reference/argilla/records/records.md b/extralit/docs/reference/argilla/records/records.md similarity index 100% rename from argilla/docs/reference/argilla/records/records.md rename to extralit/docs/reference/argilla/records/records.md diff --git a/argilla/docs/reference/argilla/records/responses.md b/extralit/docs/reference/argilla/records/responses.md similarity index 100% rename from argilla/docs/reference/argilla/records/responses.md rename to 
extralit/docs/reference/argilla/records/responses.md diff --git a/argilla/docs/reference/argilla/records/suggestions.md b/extralit/docs/reference/argilla/records/suggestions.md similarity index 100% rename from argilla/docs/reference/argilla/records/suggestions.md rename to extralit/docs/reference/argilla/records/suggestions.md diff --git a/argilla/docs/reference/argilla/records/vectors.md b/extralit/docs/reference/argilla/records/vectors.md similarity index 100% rename from argilla/docs/reference/argilla/records/vectors.md rename to extralit/docs/reference/argilla/records/vectors.md diff --git a/argilla/docs/reference/argilla/search.md b/extralit/docs/reference/argilla/search.md similarity index 100% rename from argilla/docs/reference/argilla/search.md rename to extralit/docs/reference/argilla/search.md diff --git a/argilla/docs/reference/argilla/settings/fields.md b/extralit/docs/reference/argilla/settings/fields.md similarity index 100% rename from argilla/docs/reference/argilla/settings/fields.md rename to extralit/docs/reference/argilla/settings/fields.md diff --git a/argilla/docs/reference/argilla/settings/metadata_property.md b/extralit/docs/reference/argilla/settings/metadata_property.md similarity index 100% rename from argilla/docs/reference/argilla/settings/metadata_property.md rename to extralit/docs/reference/argilla/settings/metadata_property.md diff --git a/argilla/docs/reference/argilla/settings/questions.md b/extralit/docs/reference/argilla/settings/questions.md similarity index 100% rename from argilla/docs/reference/argilla/settings/questions.md rename to extralit/docs/reference/argilla/settings/questions.md diff --git a/argilla/docs/reference/argilla/settings/settings.md b/extralit/docs/reference/argilla/settings/settings.md similarity index 100% rename from argilla/docs/reference/argilla/settings/settings.md rename to extralit/docs/reference/argilla/settings/settings.md diff --git a/argilla/docs/reference/argilla/settings/task_distribution.md 
b/extralit/docs/reference/argilla/settings/task_distribution.md similarity index 100% rename from argilla/docs/reference/argilla/settings/task_distribution.md rename to extralit/docs/reference/argilla/settings/task_distribution.md diff --git a/argilla/docs/reference/argilla/settings/vectors.md b/extralit/docs/reference/argilla/settings/vectors.md similarity index 100% rename from argilla/docs/reference/argilla/settings/vectors.md rename to extralit/docs/reference/argilla/settings/vectors.md diff --git a/argilla/docs/reference/argilla/users.md b/extralit/docs/reference/argilla/users.md similarity index 100% rename from argilla/docs/reference/argilla/users.md rename to extralit/docs/reference/argilla/users.md diff --git a/argilla/docs/reference/argilla/webhooks.md b/extralit/docs/reference/argilla/webhooks.md similarity index 100% rename from argilla/docs/reference/argilla/webhooks.md rename to extralit/docs/reference/argilla/webhooks.md diff --git a/argilla/docs/reference/argilla/workspaces.md b/extralit/docs/reference/argilla/workspaces.md similarity index 100% rename from argilla/docs/reference/argilla/workspaces.md rename to extralit/docs/reference/argilla/workspaces.md diff --git a/argilla/docs/scripts/gen_changelog.py b/extralit/docs/scripts/gen_changelog.py similarity index 93% rename from argilla/docs/scripts/gen_changelog.py rename to extralit/docs/scripts/gen_changelog.py index 3c61ea73b..ca5a649a5 100644 --- a/argilla/docs/scripts/gen_changelog.py +++ b/extralit/docs/scripts/gen_changelog.py @@ -1,4 +1,4 @@ -# Copyright 2024-present, Argilla, Inc. +# Copyright 2024-present, Extralit Labs, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,8 +18,8 @@ import mkdocs_gen_files import requests -REPOSITORY = "extralit/extralit" -CHANGELOG_PATH = "argilla/CHANGELOG.md" +REPOSITORY = "Extralit/extralit" +CHANGELOG_PATH = "extralit/CHANGELOG.md" RETRIEVED_BRANCH = "develop" DATA_PATH = "community/changelog.md" diff --git a/argilla/docs/scripts/gen_popular_issues.py b/extralit/docs/scripts/gen_popular_issues.py similarity index 100% rename from argilla/docs/scripts/gen_popular_issues.py rename to extralit/docs/scripts/gen_popular_issues.py diff --git a/argilla/docs/scripts/gen_ref_pages.py b/extralit/docs/scripts/gen_ref_pages.py similarity index 100% rename from argilla/docs/scripts/gen_ref_pages.py rename to extralit/docs/scripts/gen_ref_pages.py diff --git a/argilla/docs/stylesheets/extra.css b/extralit/docs/stylesheets/extra.css similarity index 100% rename from argilla/docs/stylesheets/extra.css rename to extralit/docs/stylesheets/extra.css diff --git a/argilla/docs/stylesheets/fonts/FontAwesome.otf b/extralit/docs/stylesheets/fonts/FontAwesome.otf similarity index 100% rename from argilla/docs/stylesheets/fonts/FontAwesome.otf rename to extralit/docs/stylesheets/fonts/FontAwesome.otf diff --git a/argilla/docs/stylesheets/fonts/fontawesome-webfont.eot b/extralit/docs/stylesheets/fonts/fontawesome-webfont.eot similarity index 100% rename from argilla/docs/stylesheets/fonts/fontawesome-webfont.eot rename to extralit/docs/stylesheets/fonts/fontawesome-webfont.eot diff --git a/argilla/docs/stylesheets/fonts/fontawesome-webfont.svg b/extralit/docs/stylesheets/fonts/fontawesome-webfont.svg similarity index 100% rename from argilla/docs/stylesheets/fonts/fontawesome-webfont.svg rename to extralit/docs/stylesheets/fonts/fontawesome-webfont.svg diff --git a/argilla/docs/stylesheets/fonts/fontawesome-webfont.ttf b/extralit/docs/stylesheets/fonts/fontawesome-webfont.ttf similarity index 100% rename from argilla/docs/stylesheets/fonts/fontawesome-webfont.ttf rename to 
extralit/docs/stylesheets/fonts/fontawesome-webfont.ttf diff --git a/argilla/docs/stylesheets/fonts/fontawesome-webfont.woff b/extralit/docs/stylesheets/fonts/fontawesome-webfont.woff similarity index 100% rename from argilla/docs/stylesheets/fonts/fontawesome-webfont.woff rename to extralit/docs/stylesheets/fonts/fontawesome-webfont.woff diff --git a/argilla/docs/stylesheets/fonts/fontawesome-webfont.woff2 b/extralit/docs/stylesheets/fonts/fontawesome-webfont.woff2 similarity index 100% rename from argilla/docs/stylesheets/fonts/fontawesome-webfont.woff2 rename to extralit/docs/stylesheets/fonts/fontawesome-webfont.woff2 diff --git a/argilla/docs/stylesheets/old.css b/extralit/docs/stylesheets/old.css similarity index 100% rename from argilla/docs/stylesheets/old.css rename to extralit/docs/stylesheets/old.css diff --git a/argilla/docs/tutorials/image_classification.ipynb b/extralit/docs/tutorials/image_classification.ipynb similarity index 100% rename from argilla/docs/tutorials/image_classification.ipynb rename to extralit/docs/tutorials/image_classification.ipynb diff --git a/argilla/docs/tutorials/image_preference.ipynb b/extralit/docs/tutorials/image_preference.ipynb similarity index 100% rename from argilla/docs/tutorials/image_preference.ipynb rename to extralit/docs/tutorials/image_preference.ipynb diff --git a/argilla/docs/tutorials/index.md b/extralit/docs/tutorials/index.md similarity index 100% rename from argilla/docs/tutorials/index.md rename to extralit/docs/tutorials/index.md diff --git a/argilla/docs/tutorials/text_classification.ipynb b/extralit/docs/tutorials/text_classification.ipynb similarity index 100% rename from argilla/docs/tutorials/text_classification.ipynb rename to extralit/docs/tutorials/text_classification.ipynb diff --git a/argilla/docs/tutorials/token_classification.ipynb b/extralit/docs/tutorials/token_classification.ipynb similarity index 100% rename from argilla/docs/tutorials/token_classification.ipynb rename to 
extralit/docs/tutorials/token_classification.ipynb diff --git a/argilla/docs/user_guide/command_line_interface.md b/extralit/docs/user_guide/command_line_interface.md similarity index 100% rename from argilla/docs/user_guide/command_line_interface.md rename to extralit/docs/user_guide/command_line_interface.md diff --git a/argilla/docs/user_guide/core_concepts.md b/extralit/docs/user_guide/core_concepts.md similarity index 100% rename from argilla/docs/user_guide/core_concepts.md rename to extralit/docs/user_guide/core_concepts.md diff --git a/argilla/docs/user_guide/documents_import.md b/extralit/docs/user_guide/documents_import.md similarity index 100% rename from argilla/docs/user_guide/documents_import.md rename to extralit/docs/user_guide/documents_import.md diff --git a/argilla/docs/user_guide/index.md b/extralit/docs/user_guide/index.md similarity index 100% rename from argilla/docs/user_guide/index.md rename to extralit/docs/user_guide/index.md diff --git a/argilla/docs/user_guide/multi_schemas.md b/extralit/docs/user_guide/multi_schemas.md similarity index 100% rename from argilla/docs/user_guide/multi_schemas.md rename to extralit/docs/user_guide/multi_schemas.md diff --git a/argilla/docs/user_guide/overview.md b/extralit/docs/user_guide/overview.md similarity index 100% rename from argilla/docs/user_guide/overview.md rename to extralit/docs/user_guide/overview.md diff --git a/argilla/docs/user_guide/schema_definition.md b/extralit/docs/user_guide/schema_definition.md similarity index 100% rename from argilla/docs/user_guide/schema_definition.md rename to extralit/docs/user_guide/schema_definition.md diff --git a/argilla/license_header.txt b/extralit/license_header.txt similarity index 100% rename from argilla/license_header.txt rename to extralit/license_header.txt diff --git a/argilla/mkdocs.yml b/extralit/mkdocs.yml similarity index 99% rename from argilla/mkdocs.yml rename to extralit/mkdocs.yml index 521a99571..7c7a5706d 100644 --- 
a/argilla/mkdocs.yml +++ b/extralit/mkdocs.yml @@ -8,7 +8,7 @@ copyright: Copyright © 2023 - 2025 Extralit # Repository repo_name: extralit/extralit repo_url: https://github.com/extralit/extralit/ -edit_uri: edit/main/argilla/docs/ +edit_uri: edit/main/extralit/docs/ extra: version: diff --git a/argilla/pdm.lock b/extralit/pdm.lock similarity index 100% rename from argilla/pdm.lock rename to extralit/pdm.lock diff --git a/argilla/pyproject.toml b/extralit/pyproject.toml similarity index 96% rename from argilla/pyproject.toml rename to extralit/pyproject.toml index 459d5925f..afa35d2cd 100644 --- a/argilla/pyproject.toml +++ b/extralit/pyproject.toml @@ -25,11 +25,11 @@ dependencies = [ "typer>=0.9.0", # for environment variables - "python-dotenv", + "python-dotenv~=1.1.0", # for extralit "minio ~= 7.2.15", - "html5lib", + "html5lib ~= 1.1", "fastapi < 1.0.0", "pypandoc ~= 1.13", "beautifulsoup4 ~= 4.12.2", @@ -48,8 +48,8 @@ dependencies = [ "json-repair ~= 0.19.2", "fastparquet >= 2023.10.0; python_version < '3.13'", "fastparquet >= 2024.4.0; python_version >= '3.13'", - "tiktoken", - "pymupdf", + "tiktoken ~= 0.9.0", + "pymupdf==1.26.0", # for llama-index "llama-index ~= 0.10.68", @@ -78,9 +78,9 @@ pdf = [ "llmsherpa ~= 0.1.3", "python-doctr ~= 0.8.1", "deepdoctection", - "pypdf", + "pypdf ~= 4.3.1", "pypdfium2", - "pymupdf", + "pymupdf==1.26.0", "pdf2image ~= 1.16.0", ] legacy = ["argilla-v1[listeners]"] diff --git a/argilla/src/argilla/__init__.py b/extralit/src/argilla/__init__.py similarity index 100% rename from argilla/src/argilla/__init__.py rename to extralit/src/argilla/__init__.py diff --git a/argilla/src/argilla/_api/__init__.py b/extralit/src/argilla/_api/__init__.py similarity index 100% rename from argilla/src/argilla/_api/__init__.py rename to extralit/src/argilla/_api/__init__.py diff --git a/argilla/src/argilla/_api/_base.py b/extralit/src/argilla/_api/_base.py similarity index 100% rename from argilla/src/argilla/_api/_base.py rename to 
extralit/src/argilla/_api/_base.py diff --git a/argilla/src/argilla/_api/_client.py b/extralit/src/argilla/_api/_client.py similarity index 100% rename from argilla/src/argilla/_api/_client.py rename to extralit/src/argilla/_api/_client.py diff --git a/argilla/src/argilla/_api/_datasets.py b/extralit/src/argilla/_api/_datasets.py similarity index 100% rename from argilla/src/argilla/_api/_datasets.py rename to extralit/src/argilla/_api/_datasets.py diff --git a/argilla/src/argilla/_api/_fields.py b/extralit/src/argilla/_api/_fields.py similarity index 100% rename from argilla/src/argilla/_api/_fields.py rename to extralit/src/argilla/_api/_fields.py diff --git a/argilla/src/argilla/_api/_http/__init__.py b/extralit/src/argilla/_api/_http/__init__.py similarity index 100% rename from argilla/src/argilla/_api/_http/__init__.py rename to extralit/src/argilla/_api/_http/__init__.py diff --git a/argilla/src/argilla/_api/_http/_client.py b/extralit/src/argilla/_api/_http/_client.py similarity index 100% rename from argilla/src/argilla/_api/_http/_client.py rename to extralit/src/argilla/_api/_http/_client.py diff --git a/argilla/src/argilla/_api/_http/_helpers.py b/extralit/src/argilla/_api/_http/_helpers.py similarity index 100% rename from argilla/src/argilla/_api/_http/_helpers.py rename to extralit/src/argilla/_api/_http/_helpers.py diff --git a/argilla/src/argilla/_api/_metadata.py b/extralit/src/argilla/_api/_metadata.py similarity index 100% rename from argilla/src/argilla/_api/_metadata.py rename to extralit/src/argilla/_api/_metadata.py diff --git a/argilla/src/argilla/_api/_questions.py b/extralit/src/argilla/_api/_questions.py similarity index 100% rename from argilla/src/argilla/_api/_questions.py rename to extralit/src/argilla/_api/_questions.py diff --git a/argilla/src/argilla/_api/_records.py b/extralit/src/argilla/_api/_records.py similarity index 100% rename from argilla/src/argilla/_api/_records.py rename to extralit/src/argilla/_api/_records.py diff 
--git a/argilla/src/argilla/_api/_token.py b/extralit/src/argilla/_api/_token.py similarity index 100% rename from argilla/src/argilla/_api/_token.py rename to extralit/src/argilla/_api/_token.py diff --git a/argilla/src/argilla/_api/_users.py b/extralit/src/argilla/_api/_users.py similarity index 100% rename from argilla/src/argilla/_api/_users.py rename to extralit/src/argilla/_api/_users.py diff --git a/argilla/src/argilla/_api/_vectors.py b/extralit/src/argilla/_api/_vectors.py similarity index 100% rename from argilla/src/argilla/_api/_vectors.py rename to extralit/src/argilla/_api/_vectors.py diff --git a/argilla/src/argilla/_api/_webhooks.py b/extralit/src/argilla/_api/_webhooks.py similarity index 100% rename from argilla/src/argilla/_api/_webhooks.py rename to extralit/src/argilla/_api/_webhooks.py diff --git a/argilla/src/argilla/_api/_workspaces.py b/extralit/src/argilla/_api/_workspaces.py similarity index 100% rename from argilla/src/argilla/_api/_workspaces.py rename to extralit/src/argilla/_api/_workspaces.py diff --git a/argilla/src/argilla/_constants.py b/extralit/src/argilla/_constants.py similarity index 100% rename from argilla/src/argilla/_constants.py rename to extralit/src/argilla/_constants.py diff --git a/argilla/src/argilla/_exceptions/__init__.py b/extralit/src/argilla/_exceptions/__init__.py similarity index 100% rename from argilla/src/argilla/_exceptions/__init__.py rename to extralit/src/argilla/_exceptions/__init__.py diff --git a/argilla/src/argilla/_exceptions/_api.py b/extralit/src/argilla/_exceptions/_api.py similarity index 100% rename from argilla/src/argilla/_exceptions/_api.py rename to extralit/src/argilla/_exceptions/_api.py diff --git a/argilla/src/argilla/_exceptions/_base.py b/extralit/src/argilla/_exceptions/_base.py similarity index 100% rename from argilla/src/argilla/_exceptions/_base.py rename to extralit/src/argilla/_exceptions/_base.py diff --git a/argilla/src/argilla/_exceptions/_client.py 
b/extralit/src/argilla/_exceptions/_client.py similarity index 100% rename from argilla/src/argilla/_exceptions/_client.py rename to extralit/src/argilla/_exceptions/_client.py diff --git a/argilla/src/argilla/_exceptions/_hub.py b/extralit/src/argilla/_exceptions/_hub.py similarity index 100% rename from argilla/src/argilla/_exceptions/_hub.py rename to extralit/src/argilla/_exceptions/_hub.py diff --git a/argilla/src/argilla/_exceptions/_metadata.py b/extralit/src/argilla/_exceptions/_metadata.py similarity index 100% rename from argilla/src/argilla/_exceptions/_metadata.py rename to extralit/src/argilla/_exceptions/_metadata.py diff --git a/argilla/src/argilla/_exceptions/_records.py b/extralit/src/argilla/_exceptions/_records.py similarity index 100% rename from argilla/src/argilla/_exceptions/_records.py rename to extralit/src/argilla/_exceptions/_records.py diff --git a/argilla/src/argilla/_exceptions/_responses.py b/extralit/src/argilla/_exceptions/_responses.py similarity index 100% rename from argilla/src/argilla/_exceptions/_responses.py rename to extralit/src/argilla/_exceptions/_responses.py diff --git a/argilla/src/argilla/_exceptions/_serialization.py b/extralit/src/argilla/_exceptions/_serialization.py similarity index 100% rename from argilla/src/argilla/_exceptions/_serialization.py rename to extralit/src/argilla/_exceptions/_serialization.py diff --git a/argilla/src/argilla/_exceptions/_settings.py b/extralit/src/argilla/_exceptions/_settings.py similarity index 100% rename from argilla/src/argilla/_exceptions/_settings.py rename to extralit/src/argilla/_exceptions/_settings.py diff --git a/argilla/src/argilla/_exceptions/_suggestions.py b/extralit/src/argilla/_exceptions/_suggestions.py similarity index 100% rename from argilla/src/argilla/_exceptions/_suggestions.py rename to extralit/src/argilla/_exceptions/_suggestions.py diff --git a/argilla/src/argilla/_helpers/__init__.py b/extralit/src/argilla/_helpers/__init__.py similarity index 100% 
rename from argilla/src/argilla/_helpers/__init__.py rename to extralit/src/argilla/_helpers/__init__.py diff --git a/argilla/src/argilla/_helpers/_dataclasses.py b/extralit/src/argilla/_helpers/_dataclasses.py similarity index 100% rename from argilla/src/argilla/_helpers/_dataclasses.py rename to extralit/src/argilla/_helpers/_dataclasses.py diff --git a/argilla/src/argilla/_helpers/_deploy.py b/extralit/src/argilla/_helpers/_deploy.py similarity index 99% rename from argilla/src/argilla/_helpers/_deploy.py rename to extralit/src/argilla/_helpers/_deploy.py index c676bc590..c007c40a7 100644 --- a/argilla/src/argilla/_helpers/_deploy.py +++ b/extralit/src/argilla/_helpers/_deploy.py @@ -1,4 +1,4 @@ -# Copyright 2024-present, Argilla, Inc. +# Copyright 2024-present, Extralit Labs, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/argilla/src/argilla/_helpers/_iterator.py b/extralit/src/argilla/_helpers/_iterator.py similarity index 100% rename from argilla/src/argilla/_helpers/_iterator.py rename to extralit/src/argilla/_helpers/_iterator.py diff --git a/argilla/src/argilla/_helpers/_log.py b/extralit/src/argilla/_helpers/_log.py similarity index 100% rename from argilla/src/argilla/_helpers/_log.py rename to extralit/src/argilla/_helpers/_log.py diff --git a/argilla/src/argilla/_helpers/_media.py b/extralit/src/argilla/_helpers/_media.py similarity index 100% rename from argilla/src/argilla/_helpers/_media.py rename to extralit/src/argilla/_helpers/_media.py diff --git a/argilla/src/argilla/_helpers/_resource_repr.py b/extralit/src/argilla/_helpers/_resource_repr.py similarity index 100% rename from argilla/src/argilla/_helpers/_resource_repr.py rename to extralit/src/argilla/_helpers/_resource_repr.py diff --git a/argilla/src/argilla/_helpers/_uuid.py b/extralit/src/argilla/_helpers/_uuid.py similarity index 100% rename from argilla/src/argilla/_helpers/_uuid.py 
rename to extralit/src/argilla/_helpers/_uuid.py diff --git a/argilla/src/argilla/_models/__init__.py b/extralit/src/argilla/_models/__init__.py similarity index 100% rename from argilla/src/argilla/_models/__init__.py rename to extralit/src/argilla/_models/__init__.py diff --git a/argilla/src/argilla/_models/_base.py b/extralit/src/argilla/_models/_base.py similarity index 100% rename from argilla/src/argilla/_models/_base.py rename to extralit/src/argilla/_models/_base.py diff --git a/argilla/src/argilla/_models/_dataset.py b/extralit/src/argilla/_models/_dataset.py similarity index 100% rename from argilla/src/argilla/_models/_dataset.py rename to extralit/src/argilla/_models/_dataset.py diff --git a/argilla/src/argilla/_models/_dataset_progress.py b/extralit/src/argilla/_models/_dataset_progress.py similarity index 100% rename from argilla/src/argilla/_models/_dataset_progress.py rename to extralit/src/argilla/_models/_dataset_progress.py diff --git a/argilla/src/argilla/_models/_documents.py b/extralit/src/argilla/_models/_documents.py similarity index 100% rename from argilla/src/argilla/_models/_documents.py rename to extralit/src/argilla/_models/_documents.py diff --git a/argilla/src/argilla/_models/_files.py b/extralit/src/argilla/_models/_files.py similarity index 100% rename from argilla/src/argilla/_models/_files.py rename to extralit/src/argilla/_models/_files.py diff --git a/argilla/src/argilla/_models/_record/__init__.py b/extralit/src/argilla/_models/_record/__init__.py similarity index 100% rename from argilla/src/argilla/_models/_record/__init__.py rename to extralit/src/argilla/_models/_record/__init__.py diff --git a/argilla/src/argilla/_models/_record/_metadata.py b/extralit/src/argilla/_models/_record/_metadata.py similarity index 100% rename from argilla/src/argilla/_models/_record/_metadata.py rename to extralit/src/argilla/_models/_record/_metadata.py diff --git a/argilla/src/argilla/_models/_record/_record.py 
b/extralit/src/argilla/_models/_record/_record.py similarity index 100% rename from argilla/src/argilla/_models/_record/_record.py rename to extralit/src/argilla/_models/_record/_record.py diff --git a/argilla/src/argilla/_models/_record/_response.py b/extralit/src/argilla/_models/_record/_response.py similarity index 100% rename from argilla/src/argilla/_models/_record/_response.py rename to extralit/src/argilla/_models/_record/_response.py diff --git a/argilla/src/argilla/_models/_record/_suggestion.py b/extralit/src/argilla/_models/_record/_suggestion.py similarity index 100% rename from argilla/src/argilla/_models/_record/_suggestion.py rename to extralit/src/argilla/_models/_record/_suggestion.py diff --git a/argilla/src/argilla/_models/_record/_vector.py b/extralit/src/argilla/_models/_record/_vector.py similarity index 100% rename from argilla/src/argilla/_models/_record/_vector.py rename to extralit/src/argilla/_models/_record/_vector.py diff --git a/argilla/src/argilla/_models/_resource.py b/extralit/src/argilla/_models/_resource.py similarity index 100% rename from argilla/src/argilla/_models/_resource.py rename to extralit/src/argilla/_models/_resource.py diff --git a/argilla/src/argilla/_models/_search.py b/extralit/src/argilla/_models/_search.py similarity index 100% rename from argilla/src/argilla/_models/_search.py rename to extralit/src/argilla/_models/_search.py diff --git a/argilla/src/argilla/_models/_settings/__init__.py b/extralit/src/argilla/_models/_settings/__init__.py similarity index 100% rename from argilla/src/argilla/_models/_settings/__init__.py rename to extralit/src/argilla/_models/_settings/__init__.py diff --git a/argilla/src/argilla/_models/_settings/_fields.py b/extralit/src/argilla/_models/_settings/_fields.py similarity index 100% rename from argilla/src/argilla/_models/_settings/_fields.py rename to extralit/src/argilla/_models/_settings/_fields.py diff --git a/argilla/src/argilla/_models/_settings/_metadata.py 
b/extralit/src/argilla/_models/_settings/_metadata.py similarity index 100% rename from argilla/src/argilla/_models/_settings/_metadata.py rename to extralit/src/argilla/_models/_settings/_metadata.py diff --git a/argilla/src/argilla/_models/_settings/_questions.py b/extralit/src/argilla/_models/_settings/_questions.py similarity index 100% rename from argilla/src/argilla/_models/_settings/_questions.py rename to extralit/src/argilla/_models/_settings/_questions.py diff --git a/argilla/src/argilla/_models/_settings/_task_distribution.py b/extralit/src/argilla/_models/_settings/_task_distribution.py similarity index 100% rename from argilla/src/argilla/_models/_settings/_task_distribution.py rename to extralit/src/argilla/_models/_settings/_task_distribution.py diff --git a/argilla/src/argilla/_models/_settings/_vectors.py b/extralit/src/argilla/_models/_settings/_vectors.py similarity index 100% rename from argilla/src/argilla/_models/_settings/_vectors.py rename to extralit/src/argilla/_models/_settings/_vectors.py diff --git a/argilla/src/argilla/_models/_user.py b/extralit/src/argilla/_models/_user.py similarity index 100% rename from argilla/src/argilla/_models/_user.py rename to extralit/src/argilla/_models/_user.py diff --git a/argilla/src/argilla/_models/_webhook.py b/extralit/src/argilla/_models/_webhook.py similarity index 100% rename from argilla/src/argilla/_models/_webhook.py rename to extralit/src/argilla/_models/_webhook.py diff --git a/argilla/src/argilla/_models/_workspace.py b/extralit/src/argilla/_models/_workspace.py similarity index 100% rename from argilla/src/argilla/_models/_workspace.py rename to extralit/src/argilla/_models/_workspace.py diff --git a/argilla/src/argilla/_resource.py b/extralit/src/argilla/_resource.py similarity index 100% rename from argilla/src/argilla/_resource.py rename to extralit/src/argilla/_resource.py diff --git a/argilla/src/argilla/_version.py b/extralit/src/argilla/_version.py similarity index 100% rename from 
argilla/src/argilla/_version.py rename to extralit/src/argilla/_version.py diff --git a/argilla/src/argilla/cli/__init__.py b/extralit/src/argilla/cli/__init__.py similarity index 100% rename from argilla/src/argilla/cli/__init__.py rename to extralit/src/argilla/cli/__init__.py diff --git a/argilla/src/argilla/cli/app.py b/extralit/src/argilla/cli/app.py similarity index 100% rename from argilla/src/argilla/cli/app.py rename to extralit/src/argilla/cli/app.py diff --git a/argilla/src/argilla/cli/callback.py b/extralit/src/argilla/cli/callback.py similarity index 100% rename from argilla/src/argilla/cli/callback.py rename to extralit/src/argilla/cli/callback.py diff --git a/argilla/src/argilla/cli/datasets/__init__.py b/extralit/src/argilla/cli/datasets/__init__.py similarity index 100% rename from argilla/src/argilla/cli/datasets/__init__.py rename to extralit/src/argilla/cli/datasets/__init__.py diff --git a/argilla/src/argilla/cli/datasets/__main__.py b/extralit/src/argilla/cli/datasets/__main__.py similarity index 100% rename from argilla/src/argilla/cli/datasets/__main__.py rename to extralit/src/argilla/cli/datasets/__main__.py diff --git a/argilla/src/argilla/cli/documents/__init__.py b/extralit/src/argilla/cli/documents/__init__.py similarity index 100% rename from argilla/src/argilla/cli/documents/__init__.py rename to extralit/src/argilla/cli/documents/__init__.py diff --git a/argilla/src/argilla/cli/documents/__main__.py b/extralit/src/argilla/cli/documents/__main__.py similarity index 100% rename from argilla/src/argilla/cli/documents/__main__.py rename to extralit/src/argilla/cli/documents/__main__.py diff --git a/argilla/src/argilla/cli/documents/add.py b/extralit/src/argilla/cli/documents/add.py similarity index 100% rename from argilla/src/argilla/cli/documents/add.py rename to extralit/src/argilla/cli/documents/add.py diff --git a/argilla/src/argilla/cli/documents/delete.py b/extralit/src/argilla/cli/documents/delete.py similarity index 100% rename 
from argilla/src/argilla/cli/documents/delete.py rename to extralit/src/argilla/cli/documents/delete.py diff --git a/argilla/src/argilla/cli/documents/list.py b/extralit/src/argilla/cli/documents/list.py similarity index 100% rename from argilla/src/argilla/cli/documents/list.py rename to extralit/src/argilla/cli/documents/list.py diff --git a/argilla/src/argilla/cli/extraction/__init__.py b/extralit/src/argilla/cli/extraction/__init__.py similarity index 100% rename from argilla/src/argilla/cli/extraction/__init__.py rename to extralit/src/argilla/cli/extraction/__init__.py diff --git a/argilla/src/argilla/cli/extraction/__main__.py b/extralit/src/argilla/cli/extraction/__main__.py similarity index 100% rename from argilla/src/argilla/cli/extraction/__main__.py rename to extralit/src/argilla/cli/extraction/__main__.py diff --git a/argilla/src/argilla/cli/extraction/export.py b/extralit/src/argilla/cli/extraction/export.py similarity index 100% rename from argilla/src/argilla/cli/extraction/export.py rename to extralit/src/argilla/cli/extraction/export.py diff --git a/argilla/src/argilla/cli/extraction/status.py b/extralit/src/argilla/cli/extraction/status.py similarity index 100% rename from argilla/src/argilla/cli/extraction/status.py rename to extralit/src/argilla/cli/extraction/status.py diff --git a/argilla/src/argilla/cli/files/__init__.py b/extralit/src/argilla/cli/files/__init__.py similarity index 100% rename from argilla/src/argilla/cli/files/__init__.py rename to extralit/src/argilla/cli/files/__init__.py diff --git a/argilla/src/argilla/cli/files/__main__.py b/extralit/src/argilla/cli/files/__main__.py similarity index 100% rename from argilla/src/argilla/cli/files/__main__.py rename to extralit/src/argilla/cli/files/__main__.py diff --git a/argilla/src/argilla/cli/files/delete.py b/extralit/src/argilla/cli/files/delete.py similarity index 100% rename from argilla/src/argilla/cli/files/delete.py rename to extralit/src/argilla/cli/files/delete.py diff 
--git a/argilla/src/argilla/cli/files/download.py b/extralit/src/argilla/cli/files/download.py similarity index 100% rename from argilla/src/argilla/cli/files/download.py rename to extralit/src/argilla/cli/files/download.py diff --git a/argilla/src/argilla/cli/files/list.py b/extralit/src/argilla/cli/files/list.py similarity index 100% rename from argilla/src/argilla/cli/files/list.py rename to extralit/src/argilla/cli/files/list.py diff --git a/argilla/src/argilla/cli/files/upload.py b/extralit/src/argilla/cli/files/upload.py similarity index 100% rename from argilla/src/argilla/cli/files/upload.py rename to extralit/src/argilla/cli/files/upload.py diff --git a/argilla/src/argilla/cli/info/__init__.py b/extralit/src/argilla/cli/info/__init__.py similarity index 100% rename from argilla/src/argilla/cli/info/__init__.py rename to extralit/src/argilla/cli/info/__init__.py diff --git a/argilla/src/argilla/cli/info/__main__.py b/extralit/src/argilla/cli/info/__main__.py similarity index 100% rename from argilla/src/argilla/cli/info/__main__.py rename to extralit/src/argilla/cli/info/__main__.py diff --git a/argilla/src/argilla/cli/login/__init__.py b/extralit/src/argilla/cli/login/__init__.py similarity index 100% rename from argilla/src/argilla/cli/login/__init__.py rename to extralit/src/argilla/cli/login/__init__.py diff --git a/argilla/src/argilla/cli/login/__main__.py b/extralit/src/argilla/cli/login/__main__.py similarity index 100% rename from argilla/src/argilla/cli/login/__main__.py rename to extralit/src/argilla/cli/login/__main__.py diff --git a/argilla/src/argilla/cli/logout/__init__.py b/extralit/src/argilla/cli/logout/__init__.py similarity index 100% rename from argilla/src/argilla/cli/logout/__init__.py rename to extralit/src/argilla/cli/logout/__init__.py diff --git a/argilla/src/argilla/cli/logout/__main__.py b/extralit/src/argilla/cli/logout/__main__.py similarity index 100% rename from argilla/src/argilla/cli/logout/__main__.py rename to 
extralit/src/argilla/cli/logout/__main__.py diff --git a/argilla/src/argilla/cli/rich.py b/extralit/src/argilla/cli/rich.py similarity index 100% rename from argilla/src/argilla/cli/rich.py rename to extralit/src/argilla/cli/rich.py diff --git a/argilla/src/argilla/cli/schemas/__init__.py b/extralit/src/argilla/cli/schemas/__init__.py similarity index 100% rename from argilla/src/argilla/cli/schemas/__init__.py rename to extralit/src/argilla/cli/schemas/__init__.py diff --git a/argilla/src/argilla/cli/schemas/__main__.py b/extralit/src/argilla/cli/schemas/__main__.py similarity index 100% rename from argilla/src/argilla/cli/schemas/__main__.py rename to extralit/src/argilla/cli/schemas/__main__.py diff --git a/argilla/src/argilla/cli/schemas/download.py b/extralit/src/argilla/cli/schemas/download.py similarity index 100% rename from argilla/src/argilla/cli/schemas/download.py rename to extralit/src/argilla/cli/schemas/download.py diff --git a/argilla/src/argilla/cli/schemas/upload.py b/extralit/src/argilla/cli/schemas/upload.py similarity index 100% rename from argilla/src/argilla/cli/schemas/upload.py rename to extralit/src/argilla/cli/schemas/upload.py diff --git a/argilla/src/argilla/cli/training/__init__.py b/extralit/src/argilla/cli/training/__init__.py similarity index 100% rename from argilla/src/argilla/cli/training/__init__.py rename to extralit/src/argilla/cli/training/__init__.py diff --git a/argilla/src/argilla/cli/training/__main__.py b/extralit/src/argilla/cli/training/__main__.py similarity index 100% rename from argilla/src/argilla/cli/training/__main__.py rename to extralit/src/argilla/cli/training/__main__.py diff --git a/argilla/src/argilla/cli/typer_ext.py b/extralit/src/argilla/cli/typer_ext.py similarity index 100% rename from argilla/src/argilla/cli/typer_ext.py rename to extralit/src/argilla/cli/typer_ext.py diff --git a/argilla/src/argilla/cli/users/__init__.py b/extralit/src/argilla/cli/users/__init__.py similarity index 100% rename from 
argilla/src/argilla/cli/users/__init__.py rename to extralit/src/argilla/cli/users/__init__.py diff --git a/argilla/src/argilla/cli/users/__main__.py b/extralit/src/argilla/cli/users/__main__.py similarity index 100% rename from argilla/src/argilla/cli/users/__main__.py rename to extralit/src/argilla/cli/users/__main__.py diff --git a/argilla/src/argilla/cli/whoami/__init__.py b/extralit/src/argilla/cli/whoami/__init__.py similarity index 100% rename from argilla/src/argilla/cli/whoami/__init__.py rename to extralit/src/argilla/cli/whoami/__init__.py diff --git a/argilla/src/argilla/cli/whoami/__main__.py b/extralit/src/argilla/cli/whoami/__main__.py similarity index 100% rename from argilla/src/argilla/cli/whoami/__main__.py rename to extralit/src/argilla/cli/whoami/__main__.py diff --git a/argilla/src/argilla/cli/workspaces/__init__.py b/extralit/src/argilla/cli/workspaces/__init__.py similarity index 100% rename from argilla/src/argilla/cli/workspaces/__init__.py rename to extralit/src/argilla/cli/workspaces/__init__.py diff --git a/argilla/src/argilla/cli/workspaces/__main__.py b/extralit/src/argilla/cli/workspaces/__main__.py similarity index 100% rename from argilla/src/argilla/cli/workspaces/__main__.py rename to extralit/src/argilla/cli/workspaces/__main__.py diff --git a/argilla/src/argilla/client/__init__.py b/extralit/src/argilla/client/__init__.py similarity index 100% rename from argilla/src/argilla/client/__init__.py rename to extralit/src/argilla/client/__init__.py diff --git a/argilla/src/argilla/client/core.py b/extralit/src/argilla/client/core.py similarity index 100% rename from argilla/src/argilla/client/core.py rename to extralit/src/argilla/client/core.py diff --git a/argilla/src/argilla/client/login.py b/extralit/src/argilla/client/login.py similarity index 100% rename from argilla/src/argilla/client/login.py rename to extralit/src/argilla/client/login.py diff --git a/argilla/src/argilla/client/resources.py 
b/extralit/src/argilla/client/resources.py similarity index 100% rename from argilla/src/argilla/client/resources.py rename to extralit/src/argilla/client/resources.py diff --git a/argilla/src/argilla/datasets/__init__.py b/extralit/src/argilla/datasets/__init__.py similarity index 100% rename from argilla/src/argilla/datasets/__init__.py rename to extralit/src/argilla/datasets/__init__.py diff --git a/argilla/src/argilla/datasets/_io/__init__.py b/extralit/src/argilla/datasets/_io/__init__.py similarity index 100% rename from argilla/src/argilla/datasets/_io/__init__.py rename to extralit/src/argilla/datasets/_io/__init__.py diff --git a/argilla/src/argilla/datasets/_io/_disk.py b/extralit/src/argilla/datasets/_io/_disk.py similarity index 98% rename from argilla/src/argilla/datasets/_io/_disk.py rename to extralit/src/argilla/datasets/_io/_disk.py index b0d53ae63..99c1f2ed4 100644 --- a/argilla/src/argilla/datasets/_io/_disk.py +++ b/extralit/src/argilla/datasets/_io/_disk.py @@ -1,4 +1,4 @@ -# Copyright 2024-present, Argilla, Inc. +# Copyright 2024-present, Extralit Labs, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ import json import logging import os @@ -35,7 +36,7 @@ class DiskImportExportMixin(ABC): _model: DatasetModel _DEFAULT_RECORDS_PATH = "records.json" - _DEFAULT_CONFIG_REPO_DIR = ".argilla" + _DEFAULT_CONFIG_REPO_DIR = ".extralit" _DEFAULT_SETTINGS_PATH = f"{_DEFAULT_CONFIG_REPO_DIR}/settings.json" _DEFAULT_DATASET_PATH = f"{_DEFAULT_CONFIG_REPO_DIR}/dataset.json" _DEFAULT_CONFIGURATION_FILES = [_DEFAULT_SETTINGS_PATH, _DEFAULT_DATASET_PATH] diff --git a/argilla/src/argilla/datasets/_io/_hub.py b/extralit/src/argilla/datasets/_io/_hub.py similarity index 100% rename from argilla/src/argilla/datasets/_io/_hub.py rename to extralit/src/argilla/datasets/_io/_hub.py diff --git a/argilla/src/argilla/datasets/_io/card/__init__.py b/extralit/src/argilla/datasets/_io/card/__init__.py similarity index 100% rename from argilla/src/argilla/datasets/_io/card/__init__.py rename to extralit/src/argilla/datasets/_io/card/__init__.py diff --git a/argilla/src/argilla/datasets/_io/card/_dataset_card.py b/extralit/src/argilla/datasets/_io/card/_dataset_card.py similarity index 100% rename from argilla/src/argilla/datasets/_io/card/_dataset_card.py rename to extralit/src/argilla/datasets/_io/card/_dataset_card.py diff --git a/argilla/src/argilla/datasets/_io/card/_parser.py b/extralit/src/argilla/datasets/_io/card/_parser.py similarity index 100% rename from argilla/src/argilla/datasets/_io/card/_parser.py rename to extralit/src/argilla/datasets/_io/card/_parser.py diff --git a/argilla/src/argilla/datasets/_io/card/argilla_template.md b/extralit/src/argilla/datasets/_io/card/argilla_template.md similarity index 100% rename from argilla/src/argilla/datasets/_io/card/argilla_template.md rename to extralit/src/argilla/datasets/_io/card/argilla_template.md diff --git a/argilla/src/argilla/datasets/_resource.py b/extralit/src/argilla/datasets/_resource.py similarity index 100% rename from argilla/src/argilla/datasets/_resource.py rename to extralit/src/argilla/datasets/_resource.py diff 
--git a/argilla/src/argilla/markdown/__init__.py b/extralit/src/argilla/markdown/__init__.py similarity index 100% rename from argilla/src/argilla/markdown/__init__.py rename to extralit/src/argilla/markdown/__init__.py diff --git a/argilla/src/argilla/markdown/chat.py b/extralit/src/argilla/markdown/chat.py similarity index 100% rename from argilla/src/argilla/markdown/chat.py rename to extralit/src/argilla/markdown/chat.py diff --git a/argilla/src/argilla/markdown/media.py b/extralit/src/argilla/markdown/media.py similarity index 100% rename from argilla/src/argilla/markdown/media.py rename to extralit/src/argilla/markdown/media.py diff --git a/argilla/src/argilla/records/__init__.py b/extralit/src/argilla/records/__init__.py similarity index 100% rename from argilla/src/argilla/records/__init__.py rename to extralit/src/argilla/records/__init__.py diff --git a/argilla/src/argilla/records/_dataset_records.py b/extralit/src/argilla/records/_dataset_records.py similarity index 100% rename from argilla/src/argilla/records/_dataset_records.py rename to extralit/src/argilla/records/_dataset_records.py diff --git a/argilla/src/argilla/records/_io/__init__.py b/extralit/src/argilla/records/_io/__init__.py similarity index 100% rename from argilla/src/argilla/records/_io/__init__.py rename to extralit/src/argilla/records/_io/__init__.py diff --git a/argilla/src/argilla/records/_io/_datasets.py b/extralit/src/argilla/records/_io/_datasets.py similarity index 100% rename from argilla/src/argilla/records/_io/_datasets.py rename to extralit/src/argilla/records/_io/_datasets.py diff --git a/argilla/src/argilla/records/_io/_generic.py b/extralit/src/argilla/records/_io/_generic.py similarity index 100% rename from argilla/src/argilla/records/_io/_generic.py rename to extralit/src/argilla/records/_io/_generic.py diff --git a/argilla/src/argilla/records/_io/_json.py b/extralit/src/argilla/records/_io/_json.py similarity index 100% rename from 
argilla/src/argilla/records/_io/_json.py rename to extralit/src/argilla/records/_io/_json.py diff --git a/argilla/src/argilla/records/_mapping/__init__.py b/extralit/src/argilla/records/_mapping/__init__.py similarity index 100% rename from argilla/src/argilla/records/_mapping/__init__.py rename to extralit/src/argilla/records/_mapping/__init__.py diff --git a/argilla/src/argilla/records/_mapping/_mapper.py b/extralit/src/argilla/records/_mapping/_mapper.py similarity index 100% rename from argilla/src/argilla/records/_mapping/_mapper.py rename to extralit/src/argilla/records/_mapping/_mapper.py diff --git a/argilla/src/argilla/records/_mapping/_routes.py b/extralit/src/argilla/records/_mapping/_routes.py similarity index 100% rename from argilla/src/argilla/records/_mapping/_routes.py rename to extralit/src/argilla/records/_mapping/_routes.py diff --git a/argilla/src/argilla/records/_resource.py b/extralit/src/argilla/records/_resource.py similarity index 100% rename from argilla/src/argilla/records/_resource.py rename to extralit/src/argilla/records/_resource.py diff --git a/argilla/src/argilla/records/_search.py b/extralit/src/argilla/records/_search.py similarity index 100% rename from argilla/src/argilla/records/_search.py rename to extralit/src/argilla/records/_search.py diff --git a/argilla/src/argilla/responses.py b/extralit/src/argilla/responses.py similarity index 100% rename from argilla/src/argilla/responses.py rename to extralit/src/argilla/responses.py diff --git a/argilla/src/argilla/settings/__init__.py b/extralit/src/argilla/settings/__init__.py similarity index 100% rename from argilla/src/argilla/settings/__init__.py rename to extralit/src/argilla/settings/__init__.py diff --git a/argilla/src/argilla/settings/_common.py b/extralit/src/argilla/settings/_common.py similarity index 100% rename from argilla/src/argilla/settings/_common.py rename to extralit/src/argilla/settings/_common.py diff --git a/argilla/src/argilla/settings/_field.py 
b/extralit/src/argilla/settings/_field.py similarity index 100% rename from argilla/src/argilla/settings/_field.py rename to extralit/src/argilla/settings/_field.py diff --git a/argilla/src/argilla/settings/_io/__init__.py b/extralit/src/argilla/settings/_io/__init__.py similarity index 100% rename from argilla/src/argilla/settings/_io/__init__.py rename to extralit/src/argilla/settings/_io/__init__.py diff --git a/argilla/src/argilla/settings/_io/_hub.py b/extralit/src/argilla/settings/_io/_hub.py similarity index 100% rename from argilla/src/argilla/settings/_io/_hub.py rename to extralit/src/argilla/settings/_io/_hub.py diff --git a/argilla/src/argilla/settings/_metadata.py b/extralit/src/argilla/settings/_metadata.py similarity index 100% rename from argilla/src/argilla/settings/_metadata.py rename to extralit/src/argilla/settings/_metadata.py diff --git a/argilla/src/argilla/settings/_question.py b/extralit/src/argilla/settings/_question.py similarity index 100% rename from argilla/src/argilla/settings/_question.py rename to extralit/src/argilla/settings/_question.py diff --git a/argilla/src/argilla/settings/_resource.py b/extralit/src/argilla/settings/_resource.py similarity index 100% rename from argilla/src/argilla/settings/_resource.py rename to extralit/src/argilla/settings/_resource.py diff --git a/argilla/src/argilla/settings/_task_distribution.py b/extralit/src/argilla/settings/_task_distribution.py similarity index 100% rename from argilla/src/argilla/settings/_task_distribution.py rename to extralit/src/argilla/settings/_task_distribution.py diff --git a/argilla/src/argilla/settings/_templates.py b/extralit/src/argilla/settings/_templates.py similarity index 100% rename from argilla/src/argilla/settings/_templates.py rename to extralit/src/argilla/settings/_templates.py diff --git a/argilla/src/argilla/settings/_vector.py b/extralit/src/argilla/settings/_vector.py similarity index 100% rename from argilla/src/argilla/settings/_vector.py rename to 
extralit/src/argilla/settings/_vector.py diff --git a/argilla/src/argilla/suggestions.py b/extralit/src/argilla/suggestions.py similarity index 100% rename from argilla/src/argilla/suggestions.py rename to extralit/src/argilla/suggestions.py diff --git a/argilla/src/argilla/users/__init__.py b/extralit/src/argilla/users/__init__.py similarity index 100% rename from argilla/src/argilla/users/__init__.py rename to extralit/src/argilla/users/__init__.py diff --git a/argilla/src/argilla/users/_resource.py b/extralit/src/argilla/users/_resource.py similarity index 100% rename from argilla/src/argilla/users/_resource.py rename to extralit/src/argilla/users/_resource.py diff --git a/argilla/src/argilla/v1/__init__.py b/extralit/src/argilla/v1/__init__.py similarity index 100% rename from argilla/src/argilla/v1/__init__.py rename to extralit/src/argilla/v1/__init__.py diff --git a/argilla/src/argilla/vectors.py b/extralit/src/argilla/vectors.py similarity index 100% rename from argilla/src/argilla/vectors.py rename to extralit/src/argilla/vectors.py diff --git a/argilla/src/argilla/webhooks/__init__.py b/extralit/src/argilla/webhooks/__init__.py similarity index 100% rename from argilla/src/argilla/webhooks/__init__.py rename to extralit/src/argilla/webhooks/__init__.py diff --git a/argilla/src/argilla/webhooks/_event.py b/extralit/src/argilla/webhooks/_event.py similarity index 100% rename from argilla/src/argilla/webhooks/_event.py rename to extralit/src/argilla/webhooks/_event.py diff --git a/argilla/src/argilla/webhooks/_handler.py b/extralit/src/argilla/webhooks/_handler.py similarity index 100% rename from argilla/src/argilla/webhooks/_handler.py rename to extralit/src/argilla/webhooks/_handler.py diff --git a/argilla/src/argilla/webhooks/_helpers.py b/extralit/src/argilla/webhooks/_helpers.py similarity index 100% rename from argilla/src/argilla/webhooks/_helpers.py rename to extralit/src/argilla/webhooks/_helpers.py diff --git 
a/argilla/src/argilla/webhooks/_resource.py b/extralit/src/argilla/webhooks/_resource.py similarity index 100% rename from argilla/src/argilla/webhooks/_resource.py rename to extralit/src/argilla/webhooks/_resource.py diff --git a/argilla/src/argilla/workspaces/__init__.py b/extralit/src/argilla/workspaces/__init__.py similarity index 100% rename from argilla/src/argilla/workspaces/__init__.py rename to extralit/src/argilla/workspaces/__init__.py diff --git a/argilla/src/argilla/workspaces/_resource.py b/extralit/src/argilla/workspaces/_resource.py similarity index 100% rename from argilla/src/argilla/workspaces/_resource.py rename to extralit/src/argilla/workspaces/_resource.py diff --git a/argilla/src/extralit/__init__.py b/extralit/src/extralit/__init__.py similarity index 100% rename from argilla/src/extralit/__init__.py rename to extralit/src/extralit/__init__.py diff --git a/argilla/src/extralit/constants.py b/extralit/src/extralit/constants.py similarity index 100% rename from argilla/src/extralit/constants.py rename to extralit/src/extralit/constants.py diff --git a/argilla/src/extralit/convert/__init__.py b/extralit/src/extralit/convert/__init__.py similarity index 100% rename from argilla/src/extralit/convert/__init__.py rename to extralit/src/extralit/convert/__init__.py diff --git a/argilla/src/extralit/convert/html_table.py b/extralit/src/extralit/convert/html_table.py similarity index 100% rename from argilla/src/extralit/convert/html_table.py rename to extralit/src/extralit/convert/html_table.py diff --git a/argilla/src/extralit/convert/json_table.py b/extralit/src/extralit/convert/json_table.py similarity index 100% rename from argilla/src/extralit/convert/json_table.py rename to extralit/src/extralit/convert/json_table.py diff --git a/argilla/src/extralit/convert/markdown.py b/extralit/src/extralit/convert/markdown.py similarity index 100% rename from argilla/src/extralit/convert/markdown.py rename to extralit/src/extralit/convert/markdown.py diff 
--git a/argilla/src/extralit/convert/pdf.py b/extralit/src/extralit/convert/pdf.py similarity index 100% rename from argilla/src/extralit/convert/pdf.py rename to extralit/src/extralit/convert/pdf.py diff --git a/argilla/src/extralit/convert/text.py b/extralit/src/extralit/convert/text.py similarity index 100% rename from argilla/src/extralit/convert/text.py rename to extralit/src/extralit/convert/text.py diff --git a/argilla/tests/extralit/__init__.py b/extralit/src/extralit/extraction/__init__.py similarity index 100% rename from argilla/tests/extralit/__init__.py rename to extralit/src/extralit/extraction/__init__.py diff --git a/argilla/src/extralit/extraction/chunking.py b/extralit/src/extralit/extraction/chunking.py similarity index 100% rename from argilla/src/extralit/extraction/chunking.py rename to extralit/src/extralit/extraction/chunking.py diff --git a/argilla/src/extralit/extraction/extraction.py b/extralit/src/extralit/extraction/extraction.py similarity index 100% rename from argilla/src/extralit/extraction/extraction.py rename to extralit/src/extralit/extraction/extraction.py diff --git a/argilla/src/extralit/extraction/models/__init__.py b/extralit/src/extralit/extraction/models/__init__.py similarity index 100% rename from argilla/src/extralit/extraction/models/__init__.py rename to extralit/src/extralit/extraction/models/__init__.py diff --git a/argilla/src/extralit/extraction/models/paper.py b/extralit/src/extralit/extraction/models/paper.py similarity index 100% rename from argilla/src/extralit/extraction/models/paper.py rename to extralit/src/extralit/extraction/models/paper.py diff --git a/argilla/src/extralit/extraction/models/response.py b/extralit/src/extralit/extraction/models/response.py similarity index 100% rename from argilla/src/extralit/extraction/models/response.py rename to extralit/src/extralit/extraction/models/response.py diff --git a/argilla/src/extralit/extraction/models/schema.py 
b/extralit/src/extralit/extraction/models/schema.py similarity index 100% rename from argilla/src/extralit/extraction/models/schema.py rename to extralit/src/extralit/extraction/models/schema.py diff --git a/argilla/src/extralit/extraction/prompts.py b/extralit/src/extralit/extraction/prompts.py similarity index 100% rename from argilla/src/extralit/extraction/prompts.py rename to extralit/src/extralit/extraction/prompts.py diff --git a/argilla/src/extralit/extraction/query.py b/extralit/src/extralit/extraction/query.py similarity index 100% rename from argilla/src/extralit/extraction/query.py rename to extralit/src/extralit/extraction/query.py diff --git a/argilla/src/extralit/extraction/schema.py b/extralit/src/extralit/extraction/schema.py similarity index 100% rename from argilla/src/extralit/extraction/schema.py rename to extralit/src/extralit/extraction/schema.py diff --git a/argilla/src/extralit/extraction/staging.py b/extralit/src/extralit/extraction/staging.py similarity index 100% rename from argilla/src/extralit/extraction/staging.py rename to extralit/src/extralit/extraction/staging.py diff --git a/argilla/src/extralit/extraction/storage.py b/extralit/src/extralit/extraction/storage.py similarity index 100% rename from argilla/src/extralit/extraction/storage.py rename to extralit/src/extralit/extraction/storage.py diff --git a/argilla/src/extralit/extraction/utils.py b/extralit/src/extralit/extraction/utils.py similarity index 100% rename from argilla/src/extralit/extraction/utils.py rename to extralit/src/extralit/extraction/utils.py diff --git a/argilla/src/extralit/extraction/vector_index.py b/extralit/src/extralit/extraction/vector_index.py similarity index 100% rename from argilla/src/extralit/extraction/vector_index.py rename to extralit/src/extralit/extraction/vector_index.py diff --git a/argilla/src/extralit/extraction/vector_store.py b/extralit/src/extralit/extraction/vector_store.py similarity index 100% rename from 
argilla/src/extralit/extraction/vector_store.py rename to extralit/src/extralit/extraction/vector_store.py diff --git a/argilla/tests/extralit/extraction/__init__.py b/extralit/src/extralit/metrics/__init__.py similarity index 100% rename from argilla/tests/extralit/extraction/__init__.py rename to extralit/src/extralit/metrics/__init__.py diff --git a/argilla/src/extralit/metrics/extraction.py b/extralit/src/extralit/metrics/extraction.py similarity index 100% rename from argilla/src/extralit/metrics/extraction.py rename to extralit/src/extralit/metrics/extraction.py diff --git a/argilla/src/extralit/metrics/grits.py b/extralit/src/extralit/metrics/grits.py similarity index 100% rename from argilla/src/extralit/metrics/grits.py rename to extralit/src/extralit/metrics/grits.py diff --git a/argilla/src/extralit/metrics/utils.py b/extralit/src/extralit/metrics/utils.py similarity index 100% rename from argilla/src/extralit/metrics/utils.py rename to extralit/src/extralit/metrics/utils.py diff --git a/argilla/tests/extralit/preprocessing/__init__.py b/extralit/src/extralit/pipeline/__init__.py similarity index 100% rename from argilla/tests/extralit/preprocessing/__init__.py rename to extralit/src/extralit/pipeline/__init__.py diff --git a/argilla/tests/extralit/server/__init__.py b/extralit/src/extralit/pipeline/export/__init__.py similarity index 100% rename from argilla/tests/extralit/server/__init__.py rename to extralit/src/extralit/pipeline/export/__init__.py diff --git a/argilla/src/extralit/pipeline/export/dataset.py b/extralit/src/extralit/pipeline/export/dataset.py similarity index 100% rename from argilla/src/extralit/pipeline/export/dataset.py rename to extralit/src/extralit/pipeline/export/dataset.py diff --git a/argilla/src/extralit/pipeline/export/record.py b/extralit/src/extralit/pipeline/export/record.py similarity index 100% rename from argilla/src/extralit/pipeline/export/record.py rename to extralit/src/extralit/pipeline/export/record.py diff --git 
a/argilla/tests/unit/cli/schemas/__init__.py b/extralit/src/extralit/pipeline/ingest/__init__.py similarity index 100% rename from argilla/tests/unit/cli/schemas/__init__.py rename to extralit/src/extralit/pipeline/ingest/__init__.py diff --git a/argilla/src/extralit/pipeline/ingest/paper.py b/extralit/src/extralit/pipeline/ingest/paper.py similarity index 100% rename from argilla/src/extralit/pipeline/ingest/paper.py rename to extralit/src/extralit/pipeline/ingest/paper.py diff --git a/argilla/src/extralit/pipeline/ingest/record.py b/extralit/src/extralit/pipeline/ingest/record.py similarity index 100% rename from argilla/src/extralit/pipeline/ingest/record.py rename to extralit/src/extralit/pipeline/ingest/record.py diff --git a/argilla/src/extralit/pipeline/ingest/segment.py b/extralit/src/extralit/pipeline/ingest/segment.py similarity index 100% rename from argilla/src/extralit/pipeline/ingest/segment.py rename to extralit/src/extralit/pipeline/ingest/segment.py diff --git a/argilla/src/extralit/pipeline/ingest/trace.py b/extralit/src/extralit/pipeline/ingest/trace.py similarity index 100% rename from argilla/src/extralit/pipeline/ingest/trace.py rename to extralit/src/extralit/pipeline/ingest/trace.py diff --git a/extralit/src/extralit/pipeline/update/__init__.py b/extralit/src/extralit/pipeline/update/__init__.py new file mode 100644 index 000000000..6da066d3e --- /dev/null +++ b/extralit/src/extralit/pipeline/update/__init__.py @@ -0,0 +1,2 @@ +from .schema import * +from .suggestion import * diff --git a/argilla/src/extralit/pipeline/update/schema.py b/extralit/src/extralit/pipeline/update/schema.py similarity index 100% rename from argilla/src/extralit/pipeline/update/schema.py rename to extralit/src/extralit/pipeline/update/schema.py diff --git a/argilla/src/extralit/pipeline/update/suggestion.py b/extralit/src/extralit/pipeline/update/suggestion.py similarity index 100% rename from argilla/src/extralit/pipeline/update/suggestion.py rename to 
extralit/src/extralit/pipeline/update/suggestion.py diff --git a/argilla/tests/unit/cli/schemas/test_delete.py b/extralit/src/extralit/preprocessing/__init__.py similarity index 100% rename from argilla/tests/unit/cli/schemas/test_delete.py rename to extralit/src/extralit/preprocessing/__init__.py diff --git a/argilla/src/extralit/preprocessing/alignment.py b/extralit/src/extralit/preprocessing/alignment.py similarity index 100% rename from argilla/src/extralit/preprocessing/alignment.py rename to extralit/src/extralit/preprocessing/alignment.py diff --git a/argilla/src/extralit/preprocessing/document.py b/extralit/src/extralit/preprocessing/document.py similarity index 100% rename from argilla/src/extralit/preprocessing/document.py rename to extralit/src/extralit/preprocessing/document.py diff --git a/argilla/src/extralit/preprocessing/figures.py b/extralit/src/extralit/preprocessing/figures.py similarity index 100% rename from argilla/src/extralit/preprocessing/figures.py rename to extralit/src/extralit/preprocessing/figures.py diff --git a/argilla/tests/unit/cli/schemas/test_upload.py b/extralit/src/extralit/preprocessing/methods/__init__.py similarity index 100% rename from argilla/tests/unit/cli/schemas/test_upload.py rename to extralit/src/extralit/preprocessing/methods/__init__.py diff --git a/argilla/src/extralit/preprocessing/methods/deepdoctection.py b/extralit/src/extralit/preprocessing/methods/deepdoctection.py similarity index 100% rename from argilla/src/extralit/preprocessing/methods/deepdoctection.py rename to extralit/src/extralit/preprocessing/methods/deepdoctection.py diff --git a/argilla/src/extralit/preprocessing/methods/llmsherpa.py b/extralit/src/extralit/preprocessing/methods/llmsherpa.py similarity index 100% rename from argilla/src/extralit/preprocessing/methods/llmsherpa.py rename to extralit/src/extralit/preprocessing/methods/llmsherpa.py diff --git a/argilla/src/extralit/preprocessing/methods/nougat.py 
b/extralit/src/extralit/preprocessing/methods/nougat.py similarity index 100% rename from argilla/src/extralit/preprocessing/methods/nougat.py rename to extralit/src/extralit/preprocessing/methods/nougat.py diff --git a/argilla/src/extralit/preprocessing/methods/unstructured.py b/extralit/src/extralit/preprocessing/methods/unstructured.py similarity index 100% rename from argilla/src/extralit/preprocessing/methods/unstructured.py rename to extralit/src/extralit/preprocessing/methods/unstructured.py diff --git a/argilla/src/extralit/preprocessing/segment.py b/extralit/src/extralit/preprocessing/segment.py similarity index 100% rename from argilla/src/extralit/preprocessing/segment.py rename to extralit/src/extralit/preprocessing/segment.py diff --git a/argilla/src/extralit/preprocessing/tables.py b/extralit/src/extralit/preprocessing/tables.py similarity index 100% rename from argilla/src/extralit/preprocessing/tables.py rename to extralit/src/extralit/preprocessing/tables.py diff --git a/argilla/src/extralit/preprocessing/text.py b/extralit/src/extralit/preprocessing/text.py similarity index 100% rename from argilla/src/extralit/preprocessing/text.py rename to extralit/src/extralit/preprocessing/text.py diff --git a/argilla/src/extralit/schema/__init__.py b/extralit/src/extralit/schema/__init__.py similarity index 100% rename from argilla/src/extralit/schema/__init__.py rename to extralit/src/extralit/schema/__init__.py diff --git a/argilla/src/extralit/schema/checks/__init__.py b/extralit/src/extralit/schema/checks/__init__.py similarity index 100% rename from argilla/src/extralit/schema/checks/__init__.py rename to extralit/src/extralit/schema/checks/__init__.py diff --git a/argilla/src/extralit/schema/checks/consistency.py b/extralit/src/extralit/schema/checks/consistency.py similarity index 100% rename from argilla/src/extralit/schema/checks/consistency.py rename to extralit/src/extralit/schema/checks/consistency.py diff --git 
a/argilla/src/extralit/schema/checks/dataframe.py b/extralit/src/extralit/schema/checks/dataframe.py similarity index 100% rename from argilla/src/extralit/schema/checks/dataframe.py rename to extralit/src/extralit/schema/checks/dataframe.py diff --git a/argilla/src/extralit/schema/checks/join.py b/extralit/src/extralit/schema/checks/join.py similarity index 100% rename from argilla/src/extralit/schema/checks/join.py rename to extralit/src/extralit/schema/checks/join.py diff --git a/argilla/src/extralit/schema/checks/multilabels.py b/extralit/src/extralit/schema/checks/multilabels.py similarity index 100% rename from argilla/src/extralit/schema/checks/multilabels.py rename to extralit/src/extralit/schema/checks/multilabels.py diff --git a/argilla/src/extralit/schema/checks/suggestion.py b/extralit/src/extralit/schema/checks/suggestion.py similarity index 100% rename from argilla/src/extralit/schema/checks/suggestion.py rename to extralit/src/extralit/schema/checks/suggestion.py diff --git a/argilla/src/extralit/schema/checks/time_elapsed.py b/extralit/src/extralit/schema/checks/time_elapsed.py similarity index 100% rename from argilla/src/extralit/schema/checks/time_elapsed.py rename to extralit/src/extralit/schema/checks/time_elapsed.py diff --git a/argilla/src/extralit/schema/checks/utils.py b/extralit/src/extralit/schema/checks/utils.py similarity index 100% rename from argilla/src/extralit/schema/checks/utils.py rename to extralit/src/extralit/schema/checks/utils.py diff --git a/extralit/src/extralit/schema/dtypes/__init__.py b/extralit/src/extralit/schema/dtypes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/argilla/src/extralit/schema/dtypes/parse.py b/extralit/src/extralit/schema/dtypes/parse.py similarity index 100% rename from argilla/src/extralit/schema/dtypes/parse.py rename to extralit/src/extralit/schema/dtypes/parse.py diff --git a/extralit/src/extralit/schema/references/__init__.py 
b/extralit/src/extralit/schema/references/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/argilla/src/extralit/schema/references/assign.py b/extralit/src/extralit/schema/references/assign.py
similarity index 100%
rename from argilla/src/extralit/schema/references/assign.py
rename to extralit/src/extralit/schema/references/assign.py
diff --git a/argilla/src/extralit/schema/registry.py b/extralit/src/extralit/schema/registry.py
similarity index 100%
rename from argilla/src/extralit/schema/registry.py
rename to extralit/src/extralit/schema/registry.py
diff --git a/extralit/src/extralit/server/__init__.py b/extralit/src/extralit/server/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/argilla/src/extralit/server/app.py b/extralit/src/extralit/server/app.py
similarity index 100%
rename from argilla/src/extralit/server/app.py
rename to extralit/src/extralit/server/app.py
diff --git a/extralit/src/extralit/server/context/__init__.py b/extralit/src/extralit/server/context/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/argilla/src/extralit/server/context/datasets.py b/extralit/src/extralit/server/context/datasets.py
similarity index 100%
rename from argilla/src/extralit/server/context/datasets.py
rename to extralit/src/extralit/server/context/datasets.py
diff --git a/extralit/src/extralit/server/context/files.py b/extralit/src/extralit/server/context/files.py
new file mode 100644
index 000000000..81e609b00
--- /dev/null
+++ b/extralit/src/extralit/server/context/files.py
@@ -0,0 +1,64 @@
+# Copyright 2024-present, Extralit Labs, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Optional
+from urllib.parse import urlparse
+from minio import Minio
+import logging
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def get_minio_client() -> Optional[Minio]:
+    """Build a Minio client from the ``S3_*`` environment variables.
+
+    Reads ``S3_ENDPOINT``, ``S3_ACCESS_KEY`` and ``S3_SECRET_KEY``. The endpoint
+    is parsed as a URL, so it needs a scheme (``http://`` or ``https://``) for a
+    hostname to be found; TLS is enabled only when the scheme is ``https``.
+
+    Returns:
+        A ``Minio`` client, or ``None`` when ``S3_ENDPOINT`` is unset, has no
+        hostname, or parsing the endpoint / building the client raises.
+    """
+    s3_endpoint = os.getenv("S3_ENDPOINT")
+    s3_access_key = os.getenv("S3_ACCESS_KEY")
+    s3_secret_key = os.getenv("S3_SECRET_KEY")
+
+    if s3_endpoint is None:
+        return None
+
+    try:
+        parsed_url = urlparse(s3_endpoint)
+        hostname = parsed_url.hostname
+        port = parsed_url.port
+
+        if hostname is None:
+            _LOGGER.error(
+                f"Invalid URL: no hostname in S3_ENDPOINT found, possible due to lacking http(s) protocol. Given '{s3_endpoint}'"
+            )
+            return None
+
+        # Pass only "host[:port]"; the scheme is conveyed through `secure`.
+        return Minio(
+            endpoint=f"{hostname}:{port}" if port else hostname,
+            access_key=s3_access_key,
+            secret_key=s3_secret_key,
+            secure=parsed_url.scheme == "https",
+        )
+    except Exception as e:
+        # exc_info attaches the traceback of the caught exception; stack_info
+        # would only record the stack leading to this logging call.
+        _LOGGER.error(f"Error creating Minio client: {e}", exc_info=True)
+        return None
diff --git a/argilla/src/extralit/server/context/llamaindex.py b/extralit/src/extralit/server/context/llamaindex.py
similarity index 100%
rename from argilla/src/extralit/server/context/llamaindex.py
rename to extralit/src/extralit/server/context/llamaindex.py
diff --git a/argilla/src/extralit/server/context/vectordb.py b/extralit/src/extralit/server/context/vectordb.py
similarity index 100%
rename from argilla/src/extralit/server/context/vectordb.py
rename to extralit/src/extralit/server/context/vectordb.py
diff --git a/argilla/src/extralit/server/errors.py b/extralit/src/extralit/server/errors.py
similarity index 100%
rename from argilla/src/extralit/server/errors.py
rename to extralit/src/extralit/server/errors.py
diff --git a/extralit/src/extralit/server/models/__init__.py b/extralit/src/extralit/server/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/argilla/src/extralit/server/models/extraction.py b/extralit/src/extralit/server/models/extraction.py
similarity index 100%
rename from argilla/src/extralit/server/models/extraction.py
rename to extralit/src/extralit/server/models/extraction.py
diff --git a/argilla/src/extralit/server/models/segments.py b/extralit/src/extralit/server/models/segments.py
similarity index 100%
rename from argilla/src/extralit/server/models/segments.py
rename to extralit/src/extralit/server/models/segments.py
diff --git a/extralit/src/extralit/storage/__init__.py b/extralit/src/extralit/storage/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/argilla/src/extralit/storage/files.py b/extralit/src/extralit/storage/files.py
similarity index 100%
rename from
argilla/src/extralit/storage/files.py rename to extralit/src/extralit/storage/files.py diff --git a/argilla/src/extralit/storage/singleton.py b/extralit/src/extralit/storage/singleton.py similarity index 100% rename from argilla/src/extralit/storage/singleton.py rename to extralit/src/extralit/storage/singleton.py diff --git a/extralit/tests/extralit/__init__.py b/extralit/tests/extralit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/argilla/tests/extralit/assets/schemas/itncondition.yaml b/extralit/tests/extralit/assets/schemas/itncondition.yaml similarity index 100% rename from argilla/tests/extralit/assets/schemas/itncondition.yaml rename to extralit/tests/extralit/assets/schemas/itncondition.yaml diff --git a/argilla/tests/extralit/assets/schemas/observation.yaml b/extralit/tests/extralit/assets/schemas/observation.yaml similarity index 100% rename from argilla/tests/extralit/assets/schemas/observation.yaml rename to extralit/tests/extralit/assets/schemas/observation.yaml diff --git a/argilla/tests/extralit/assets/schemas/publication.yaml b/extralit/tests/extralit/assets/schemas/publication.yaml similarity index 100% rename from argilla/tests/extralit/assets/schemas/publication.yaml rename to extralit/tests/extralit/assets/schemas/publication.yaml diff --git a/argilla/tests/extralit/conftest.py b/extralit/tests/extralit/conftest.py similarity index 100% rename from argilla/tests/extralit/conftest.py rename to extralit/tests/extralit/conftest.py diff --git a/extralit/tests/extralit/extraction/__init__.py b/extralit/tests/extralit/extraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/argilla/tests/extralit/helpers.py b/extralit/tests/extralit/helpers.py similarity index 100% rename from argilla/tests/extralit/helpers.py rename to extralit/tests/extralit/helpers.py diff --git a/extralit/tests/extralit/metrics/__init__.py b/extralit/tests/extralit/metrics/__init__.py new file mode 100644 index 
000000000..8b1378917 --- /dev/null +++ b/extralit/tests/extralit/metrics/__init__.py @@ -0,0 +1 @@ + diff --git a/argilla/tests/extralit/metrics/conftest.py b/extralit/tests/extralit/metrics/conftest.py similarity index 100% rename from argilla/tests/extralit/metrics/conftest.py rename to extralit/tests/extralit/metrics/conftest.py diff --git a/argilla/tests/extralit/metrics/test_extraction.py b/extralit/tests/extralit/metrics/test_extraction.py similarity index 100% rename from argilla/tests/extralit/metrics/test_extraction.py rename to extralit/tests/extralit/metrics/test_extraction.py diff --git a/argilla/tests/extralit/metrics/test_grits.py b/extralit/tests/extralit/metrics/test_grits.py similarity index 100% rename from argilla/tests/extralit/metrics/test_grits.py rename to extralit/tests/extralit/metrics/test_grits.py diff --git a/argilla/tests/extralit/metrics/test_utils.py b/extralit/tests/extralit/metrics/test_utils.py similarity index 100% rename from argilla/tests/extralit/metrics/test_utils.py rename to extralit/tests/extralit/metrics/test_utils.py diff --git a/extralit/tests/extralit/preprocessing/__init__.py b/extralit/tests/extralit/preprocessing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/argilla/tests/extralit/preprocessing/conftest.py b/extralit/tests/extralit/preprocessing/conftest.py similarity index 100% rename from argilla/tests/extralit/preprocessing/conftest.py rename to extralit/tests/extralit/preprocessing/conftest.py diff --git a/argilla/tests/extralit/preprocessing/test_document.py b/extralit/tests/extralit/preprocessing/test_document.py similarity index 100% rename from argilla/tests/extralit/preprocessing/test_document.py rename to extralit/tests/extralit/preprocessing/test_document.py diff --git a/extralit/tests/extralit/server/__init__.py b/extralit/tests/extralit/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/argilla/tests/extralit/server/test_app.py 
b/extralit/tests/extralit/server/test_app.py similarity index 100% rename from argilla/tests/extralit/server/test_app.py rename to extralit/tests/extralit/server/test_app.py diff --git a/argilla/tests/extralit/utils.py b/extralit/tests/extralit/utils.py similarity index 100% rename from argilla/tests/extralit/utils.py rename to extralit/tests/extralit/utils.py diff --git a/argilla/tests/integration/conftest.py b/extralit/tests/integration/conftest.py similarity index 100% rename from argilla/tests/integration/conftest.py rename to extralit/tests/integration/conftest.py diff --git a/argilla/tests/integration/test_add_records.py b/extralit/tests/integration/test_add_records.py similarity index 100% rename from argilla/tests/integration/test_add_records.py rename to extralit/tests/integration/test_add_records.py diff --git a/argilla/tests/integration/test_cli_commands.py b/extralit/tests/integration/test_cli_commands.py similarity index 100% rename from argilla/tests/integration/test_cli_commands.py rename to extralit/tests/integration/test_cli_commands.py diff --git a/argilla/tests/integration/test_client.py b/extralit/tests/integration/test_client.py similarity index 100% rename from argilla/tests/integration/test_client.py rename to extralit/tests/integration/test_client.py diff --git a/argilla/tests/integration/test_create_datasets.py b/extralit/tests/integration/test_create_datasets.py similarity index 100% rename from argilla/tests/integration/test_create_datasets.py rename to extralit/tests/integration/test_create_datasets.py diff --git a/argilla/tests/integration/test_dataset_workspace.py b/extralit/tests/integration/test_dataset_workspace.py similarity index 100% rename from argilla/tests/integration/test_dataset_workspace.py rename to extralit/tests/integration/test_dataset_workspace.py diff --git a/argilla/tests/integration/test_delete_records.py b/extralit/tests/integration/test_delete_records.py similarity index 100% rename from 
argilla/tests/integration/test_delete_records.py rename to extralit/tests/integration/test_delete_records.py diff --git a/argilla/tests/integration/test_empty_settings.py b/extralit/tests/integration/test_empty_settings.py similarity index 100% rename from argilla/tests/integration/test_empty_settings.py rename to extralit/tests/integration/test_empty_settings.py diff --git a/argilla/tests/integration/test_export_dataset.py b/extralit/tests/integration/test_export_dataset.py similarity index 100% rename from argilla/tests/integration/test_export_dataset.py rename to extralit/tests/integration/test_export_dataset.py diff --git a/argilla/tests/integration/test_export_records.py b/extralit/tests/integration/test_export_records.py similarity index 100% rename from argilla/tests/integration/test_export_records.py rename to extralit/tests/integration/test_export_records.py diff --git a/argilla/tests/integration/test_import_features.py b/extralit/tests/integration/test_import_features.py similarity index 100% rename from argilla/tests/integration/test_import_features.py rename to extralit/tests/integration/test_import_features.py diff --git a/argilla/tests/integration/test_list_records.py b/extralit/tests/integration/test_list_records.py similarity index 100% rename from argilla/tests/integration/test_list_records.py rename to extralit/tests/integration/test_list_records.py diff --git a/argilla/tests/integration/test_listing_datasets.py b/extralit/tests/integration/test_listing_datasets.py similarity index 100% rename from argilla/tests/integration/test_listing_datasets.py rename to extralit/tests/integration/test_listing_datasets.py diff --git a/argilla/tests/integration/test_manage_metadata.py b/extralit/tests/integration/test_manage_metadata.py similarity index 100% rename from argilla/tests/integration/test_manage_metadata.py rename to extralit/tests/integration/test_manage_metadata.py diff --git a/argilla/tests/integration/test_manage_users.py 
b/extralit/tests/integration/test_manage_users.py similarity index 100% rename from argilla/tests/integration/test_manage_users.py rename to extralit/tests/integration/test_manage_users.py diff --git a/argilla/tests/integration/test_manage_workspaces.py b/extralit/tests/integration/test_manage_workspaces.py similarity index 100% rename from argilla/tests/integration/test_manage_workspaces.py rename to extralit/tests/integration/test_manage_workspaces.py diff --git a/argilla/tests/integration/test_publish_datasets.py b/extralit/tests/integration/test_publish_datasets.py similarity index 100% rename from argilla/tests/integration/test_publish_datasets.py rename to extralit/tests/integration/test_publish_datasets.py diff --git a/argilla/tests/integration/test_query_records.py b/extralit/tests/integration/test_query_records.py similarity index 100% rename from argilla/tests/integration/test_query_records.py rename to extralit/tests/integration/test_query_records.py diff --git a/argilla/tests/integration/test_ranking_questions.py b/extralit/tests/integration/test_ranking_questions.py similarity index 100% rename from argilla/tests/integration/test_ranking_questions.py rename to extralit/tests/integration/test_ranking_questions.py diff --git a/argilla/tests/integration/test_search_records.py b/extralit/tests/integration/test_search_records.py similarity index 97% rename from argilla/tests/integration/test_search_records.py rename to extralit/tests/integration/test_search_records.py index 5ba82d229..a0b729ce5 100644 --- a/argilla/tests/integration/test_search_records.py +++ b/extralit/tests/integration/test_search_records.py @@ -194,6 +194,10 @@ def test_search_records_by_least_similar_value(self, client: Argilla, dataset: D ) ) ) + + if records and str(data[3]["id"]) == records[0][0].id: + pytest.skip("Random tie: least similar record is the same as the query record. 
Skipping flaky test.") + assert records[0][0].id != str(data[3]["id"]) def test_search_records_by_similar_record(self, client: Argilla, dataset: Dataset): diff --git a/argilla/tests/integration/test_update_dataset_settings.py b/extralit/tests/integration/test_update_dataset_settings.py similarity index 100% rename from argilla/tests/integration/test_update_dataset_settings.py rename to extralit/tests/integration/test_update_dataset_settings.py diff --git a/argilla/tests/integration/test_update_records.py b/extralit/tests/integration/test_update_records.py similarity index 100% rename from argilla/tests/integration/test_update_records.py rename to extralit/tests/integration/test_update_records.py diff --git a/argilla/tests/integration/test_vectors.py b/extralit/tests/integration/test_vectors.py similarity index 100% rename from argilla/tests/integration/test_vectors.py rename to extralit/tests/integration/test_vectors.py diff --git a/argilla/tests/integration/test_workspace_documents.py b/extralit/tests/integration/test_workspace_documents.py similarity index 100% rename from argilla/tests/integration/test_workspace_documents.py rename to extralit/tests/integration/test_workspace_documents.py diff --git a/argilla/tests/integration/test_workspace_files.py b/extralit/tests/integration/test_workspace_files.py similarity index 100% rename from argilla/tests/integration/test_workspace_files.py rename to extralit/tests/integration/test_workspace_files.py diff --git a/argilla/tests/integration/test_workspace_schemas.py b/extralit/tests/integration/test_workspace_schemas.py similarity index 100% rename from argilla/tests/integration/test_workspace_schemas.py rename to extralit/tests/integration/test_workspace_schemas.py diff --git a/argilla/tests/unit/api/http/test_http_client.py b/extralit/tests/unit/api/http/test_http_client.py similarity index 100% rename from argilla/tests/unit/api/http/test_http_client.py rename to extralit/tests/unit/api/http/test_http_client.py diff 
--git a/argilla/tests/unit/api/test_workspace_files_api.py b/extralit/tests/unit/api/test_workspace_files_api.py similarity index 100% rename from argilla/tests/unit/api/test_workspace_files_api.py rename to extralit/tests/unit/api/test_workspace_files_api.py diff --git a/argilla/tests/unit/api/test_workspace_schemas_api.py b/extralit/tests/unit/api/test_workspace_schemas_api.py similarity index 100% rename from argilla/tests/unit/api/test_workspace_schemas_api.py rename to extralit/tests/unit/api/test_workspace_schemas_api.py diff --git a/extralit/tests/unit/cli/schemas/__init__.py b/extralit/tests/unit/cli/schemas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/extralit/tests/unit/cli/schemas/test_delete.py b/extralit/tests/unit/cli/schemas/test_delete.py new file mode 100644 index 000000000..e69de29bb diff --git a/extralit/tests/unit/cli/schemas/test_upload.py b/extralit/tests/unit/cli/schemas/test_upload.py new file mode 100644 index 000000000..e69de29bb diff --git a/argilla/tests/unit/cli/test_cli_app.py b/extralit/tests/unit/cli/test_cli_app.py similarity index 100% rename from argilla/tests/unit/cli/test_cli_app.py rename to extralit/tests/unit/cli/test_cli_app.py diff --git a/argilla/tests/unit/cli/test_cli_datasets.py b/extralit/tests/unit/cli/test_cli_datasets.py similarity index 100% rename from argilla/tests/unit/cli/test_cli_datasets.py rename to extralit/tests/unit/cli/test_cli_datasets.py diff --git a/argilla/tests/unit/cli/test_cli_extraction.py b/extralit/tests/unit/cli/test_cli_extraction.py similarity index 100% rename from argilla/tests/unit/cli/test_cli_extraction.py rename to extralit/tests/unit/cli/test_cli_extraction.py diff --git a/argilla/tests/unit/cli/test_cli_schemas.py b/extralit/tests/unit/cli/test_cli_schemas.py similarity index 100% rename from argilla/tests/unit/cli/test_cli_schemas.py rename to extralit/tests/unit/cli/test_cli_schemas.py diff --git a/argilla/tests/unit/cli/test_cli_training.py 
b/extralit/tests/unit/cli/test_cli_training.py similarity index 100% rename from argilla/tests/unit/cli/test_cli_training.py rename to extralit/tests/unit/cli/test_cli_training.py diff --git a/argilla/tests/unit/cli/test_cli_users.py b/extralit/tests/unit/cli/test_cli_users.py similarity index 100% rename from argilla/tests/unit/cli/test_cli_users.py rename to extralit/tests/unit/cli/test_cli_users.py diff --git a/argilla/tests/unit/cli/test_cli_workspaces.py b/extralit/tests/unit/cli/test_cli_workspaces.py similarity index 100% rename from argilla/tests/unit/cli/test_cli_workspaces.py rename to extralit/tests/unit/cli/test_cli_workspaces.py diff --git a/argilla/tests/unit/conftest.py b/extralit/tests/unit/conftest.py similarity index 100% rename from argilla/tests/unit/conftest.py rename to extralit/tests/unit/conftest.py diff --git a/argilla/tests/unit/export/test_record_export_import_compatibillity.py b/extralit/tests/unit/export/test_record_export_import_compatibillity.py similarity index 100% rename from argilla/tests/unit/export/test_record_export_import_compatibillity.py rename to extralit/tests/unit/export/test_record_export_import_compatibillity.py diff --git a/argilla/tests/unit/export/test_settings_export_import_compatibillity.py b/extralit/tests/unit/export/test_settings_export_import_compatibillity.py similarity index 100% rename from argilla/tests/unit/export/test_settings_export_import_compatibillity.py rename to extralit/tests/unit/export/test_settings_export_import_compatibillity.py diff --git a/argilla/tests/unit/helpers/test_resource_repr.py b/extralit/tests/unit/helpers/test_resource_repr.py similarity index 100% rename from argilla/tests/unit/helpers/test_resource_repr.py rename to extralit/tests/unit/helpers/test_resource_repr.py diff --git a/argilla/tests/unit/helpers/test_spaces_deployment.py b/extralit/tests/unit/helpers/test_spaces_deployment.py similarity index 100% rename from argilla/tests/unit/helpers/test_spaces_deployment.py rename 
to extralit/tests/unit/helpers/test_spaces_deployment.py diff --git a/argilla/tests/unit/models/__init__.py b/extralit/tests/unit/models/__init__.py similarity index 100% rename from argilla/tests/unit/models/__init__.py rename to extralit/tests/unit/models/__init__.py diff --git a/argilla/tests/unit/models/test_workspace_models.py b/extralit/tests/unit/models/test_workspace_models.py similarity index 100% rename from argilla/tests/unit/models/test_workspace_models.py rename to extralit/tests/unit/models/test_workspace_models.py diff --git a/argilla/tests/unit/test_interface.py b/extralit/tests/unit/test_interface.py similarity index 100% rename from argilla/tests/unit/test_interface.py rename to extralit/tests/unit/test_interface.py diff --git a/argilla/tests/unit/test_io/test_generic.py b/extralit/tests/unit/test_io/test_generic.py similarity index 100% rename from argilla/tests/unit/test_io/test_generic.py rename to extralit/tests/unit/test_io/test_generic.py diff --git a/argilla/tests/unit/test_io/test_hf_datasets.py b/extralit/tests/unit/test_io/test_hf_datasets.py similarity index 100% rename from argilla/tests/unit/test_io/test_hf_datasets.py rename to extralit/tests/unit/test_io/test_hf_datasets.py diff --git a/argilla/tests/unit/test_markdown.py b/extralit/tests/unit/test_markdown.py similarity index 100% rename from argilla/tests/unit/test_markdown.py rename to extralit/tests/unit/test_markdown.py diff --git a/argilla/tests/unit/test_media.py b/extralit/tests/unit/test_media.py similarity index 100% rename from argilla/tests/unit/test_media.py rename to extralit/tests/unit/test_media.py diff --git a/argilla/tests/unit/test_record_fields.py b/extralit/tests/unit/test_record_fields.py similarity index 100% rename from argilla/tests/unit/test_record_fields.py rename to extralit/tests/unit/test_record_fields.py diff --git a/argilla/tests/unit/test_record_ingestion.py b/extralit/tests/unit/test_record_ingestion.py similarity index 100% rename from 
argilla/tests/unit/test_record_ingestion.py rename to extralit/tests/unit/test_record_ingestion.py diff --git a/argilla/tests/unit/test_record_responses.py b/extralit/tests/unit/test_record_responses.py similarity index 100% rename from argilla/tests/unit/test_record_responses.py rename to extralit/tests/unit/test_record_responses.py diff --git a/argilla/tests/unit/test_record_suggestions.py b/extralit/tests/unit/test_record_suggestions.py similarity index 100% rename from argilla/tests/unit/test_record_suggestions.py rename to extralit/tests/unit/test_record_suggestions.py diff --git a/argilla/tests/unit/test_resources/test_datasets.py b/extralit/tests/unit/test_resources/test_datasets.py similarity index 100% rename from argilla/tests/unit/test_resources/test_datasets.py rename to extralit/tests/unit/test_resources/test_datasets.py diff --git a/argilla/tests/unit/test_resources/test_fields.py b/extralit/tests/unit/test_resources/test_fields.py similarity index 100% rename from argilla/tests/unit/test_resources/test_fields.py rename to extralit/tests/unit/test_resources/test_fields.py diff --git a/argilla/tests/unit/test_resources/test_questions.py b/extralit/tests/unit/test_resources/test_questions.py similarity index 100% rename from argilla/tests/unit/test_resources/test_questions.py rename to extralit/tests/unit/test_resources/test_questions.py diff --git a/argilla/tests/unit/test_resources/test_records.py b/extralit/tests/unit/test_resources/test_records.py similarity index 100% rename from argilla/tests/unit/test_resources/test_records.py rename to extralit/tests/unit/test_resources/test_records.py diff --git a/argilla/tests/unit/test_resources/test_responses.py b/extralit/tests/unit/test_resources/test_responses.py similarity index 100% rename from argilla/tests/unit/test_resources/test_responses.py rename to extralit/tests/unit/test_resources/test_responses.py diff --git a/argilla/tests/unit/test_resources/test_users.py 
b/extralit/tests/unit/test_resources/test_users.py similarity index 100% rename from argilla/tests/unit/test_resources/test_users.py rename to extralit/tests/unit/test_resources/test_users.py diff --git a/argilla/tests/unit/test_resources/test_workspaces.py b/extralit/tests/unit/test_resources/test_workspaces.py similarity index 100% rename from argilla/tests/unit/test_resources/test_workspaces.py rename to extralit/tests/unit/test_resources/test_workspaces.py diff --git a/argilla/tests/unit/test_search/test_filters.py b/extralit/tests/unit/test_search/test_filters.py similarity index 100% rename from argilla/tests/unit/test_search/test_filters.py rename to extralit/tests/unit/test_search/test_filters.py diff --git a/argilla/tests/unit/test_settings/test_metadata.py b/extralit/tests/unit/test_settings/test_metadata.py similarity index 100% rename from argilla/tests/unit/test_settings/test_metadata.py rename to extralit/tests/unit/test_settings/test_metadata.py diff --git a/argilla/tests/unit/test_settings/test_multi_label_question.py b/extralit/tests/unit/test_settings/test_multi_label_question.py similarity index 100% rename from argilla/tests/unit/test_settings/test_multi_label_question.py rename to extralit/tests/unit/test_settings/test_multi_label_question.py diff --git a/argilla/tests/unit/test_settings/test_settings.py b/extralit/tests/unit/test_settings/test_settings.py similarity index 100% rename from argilla/tests/unit/test_settings/test_settings.py rename to extralit/tests/unit/test_settings/test_settings.py diff --git a/argilla/tests/unit/test_settings/test_settings_fields.py b/extralit/tests/unit/test_settings/test_settings_fields.py similarity index 100% rename from argilla/tests/unit/test_settings/test_settings_fields.py rename to extralit/tests/unit/test_settings/test_settings_fields.py diff --git a/argilla/tests/unit/test_settings/test_settings_from_features.py b/extralit/tests/unit/test_settings/test_settings_from_features.py similarity index 100% 
rename from argilla/tests/unit/test_settings/test_settings_from_features.py rename to extralit/tests/unit/test_settings/test_settings_from_features.py diff --git a/argilla/tests/unit/test_settings/test_settings_mapping_record_ingestion.py b/extralit/tests/unit/test_settings/test_settings_mapping_record_ingestion.py similarity index 100% rename from argilla/tests/unit/test_settings/test_settings_mapping_record_ingestion.py rename to extralit/tests/unit/test_settings/test_settings_mapping_record_ingestion.py diff --git a/argilla/tests/unit/test_settings/test_settings_questions.py b/extralit/tests/unit/test_settings/test_settings_questions.py similarity index 100% rename from argilla/tests/unit/test_settings/test_settings_questions.py rename to extralit/tests/unit/test_settings/test_settings_questions.py diff --git a/argilla/tests/unit/test_settings/test_settings_templates.py b/extralit/tests/unit/test_settings/test_settings_templates.py similarity index 100% rename from argilla/tests/unit/test_settings/test_settings_templates.py rename to extralit/tests/unit/test_settings/test_settings_templates.py diff --git a/argilla/tests/unit/test_settings/test_span_question.py b/extralit/tests/unit/test_settings/test_span_question.py similarity index 100% rename from argilla/tests/unit/test_settings/test_span_question.py rename to extralit/tests/unit/test_settings/test_span_question.py