Source code for tests.conftest

# Copyright 2025-, Semiotic AI, Inc.
# SPDX-License-Identifier: Apache-2.0

import logging

# system packages
import os
from enum import Enum
from pathlib import Path

from dotenv import load_dotenv

# external packages
from pytest import fixture

# internal packages
from graphdoc import (
    DocGeneratorPrompt,
    DocQualityPrompt,
    LocalDataHelper,
    Parser,
    setup_logging,
)

# logging
setup_logging("INFO")
log = logging.getLogger(__name__)

# define test asset paths
TEST_DIR = Path(__file__).resolve().parent
ASSETS_DIR = TEST_DIR / "assets"
MLRUNS_DIR = ASSETS_DIR / "mlruns"
ENV_PATH = TEST_DIR / ".env"

# set the environment variables
os.environ["MLFLOW_TRACKING_URI"] = str(MLRUNS_DIR)
os.environ["MLFLOW_TRACKING_USERNAME"] = "admin"
os.environ["MLFLOW_TRACKING_PASSWORD"] = "password"

# Check if .env file exists
if not ENV_PATH.exists():
    log.error(f".env file not found at {ENV_PATH}")
else:
    log.info(f".env file found at {ENV_PATH}")
    load_dotenv(dotenv_path=ENV_PATH, override=True)
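
# Illustrative .env layout (placeholder values only; these keys are the ones
# checked in ensure_env_vars below):
#
#   OPENAI_API_KEY=...
#   HF_DATASET_KEY=...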


# Set default environment variables if not present
def ensure_env_vars():
    """Ensure all required environment variables are set with defaults if needed."""
    env_defaults = {
        "OPENAI_API_KEY": None,  # No default, must be provided
        "HF_DATASET_KEY": None,  # No default, must be provided
        "MLFLOW_TRACKING_URI": str(MLRUNS_DIR),
    }
    log.info(f"Environment variable path: {ENV_PATH}")
    for key in env_defaults:
        value = os.environ.get(key, "NOT SET")
        if value != "NOT SET":
            if "API_KEY" in key or "DATASET_KEY" in key:
                log.info(f"Environment variable {key}: SET (value masked)")
            else:
                log.info(f"Environment variable {key}: SET to {value}")
        else:
            log.info(f"Environment variable {key}: NOT SET")
    for key, default in env_defaults.items():
        if key not in os.environ and default is not None:
            os.environ[key] = default
            log.info(f"Setting default for {key}: {default}")
        elif key not in os.environ and default is None:
            log.warning(f"Required environment variable {key} not set")


@fixture(autouse=True, scope="session")
def setup_env():
    """Session-scoped fixture ensuring the environment is set up before tests run."""
    if ENV_PATH.exists():
        load_dotenv(dotenv_path=ENV_PATH, override=True)
    ensure_env_vars()


class OverwriteSchemaCategory(Enum):
    PERFECT = "perfect (TEST)"
    ALMOST_PERFECT = "almost perfect (TEST)"
    POOR_BUT_CORRECT = "poor but correct (TEST)"
    INCORRECT = "incorrect (TEST)"
    BLANK = "blank (TEST)"


class OverwriteSchemaRating(Enum):
    FOUR = "8"
    THREE = "6"
    TWO = "4"
    ONE = "2"
    ZERO = "0"


class OverwriteSchemaCategoryRatingMapping:
    def get_rating(
        self, category: OverwriteSchemaCategory
    ) -> OverwriteSchemaRating:
        mapping = {
            OverwriteSchemaCategory.PERFECT: OverwriteSchemaRating.FOUR,
            OverwriteSchemaCategory.ALMOST_PERFECT: OverwriteSchemaRating.THREE,
            OverwriteSchemaCategory.POOR_BUT_CORRECT: OverwriteSchemaRating.TWO,
            OverwriteSchemaCategory.INCORRECT: OverwriteSchemaRating.ONE,
            OverwriteSchemaCategory.BLANK: OverwriteSchemaRating.ZERO,
        }
        return mapping.get(category, OverwriteSchemaRating.ZERO)
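

# Illustrative use of the mapping above (not part of the fixtures); unmapped
# categories fall back to OverwriteSchemaRating.ZERO:
#
#   mapping = OverwriteSchemaCategoryRatingMapping()
#   assert mapping.get_rating(OverwriteSchemaCategory.PERFECT) is OverwriteSchemaRating.FOUR
#   assert mapping.get_rating(OverwriteSchemaCategory.BLANK) is OverwriteSchemaRating.ZERO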


@fixture
def par() -> Parser:
    return Parser()


@fixture
def default_ldh() -> LocalDataHelper:
    return LocalDataHelper()


@fixture
def overwrite_ldh() -> LocalDataHelper:
    return LocalDataHelper(
        categories=OverwriteSchemaCategory,
        ratings=OverwriteSchemaRating,
        categories_ratings=OverwriteSchemaCategoryRatingMapping.get_rating,
    )


@fixture
def dqp():
    return DocQualityPrompt(
        prompt="doc_quality",
        prompt_type="predict",
        prompt_metric="rating",
    )


@fixture
def dgp():
    return DocGeneratorPrompt(
        prompt="base_doc_gen",
        prompt_type="chain_of_thought",
        prompt_metric=DocQualityPrompt(
            prompt="doc_quality",
            prompt_type="predict",
            prompt_metric="rating",
        ),
    )


@fixture
def mlflow_dict():
    return {
        "mlflow_tracking_uri": MLRUNS_DIR,
        "mlflow_tracking_username": "admin",
        "mlflow_tracking_password": "password",
    }
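

# Illustrative usage (a hypothetical test, not part of this module): pytest
# injects the fixtures above by parameter name, for example:
#
#   def test_parser_fixture(par: Parser):
#       assert isinstance(par, Parser)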