# Copyright 2025-, Semiotic AI, Inc.
# SPDX-License-Identifier: Apache-2.0
import logging
# system packages
import os
from enum import Enum
from pathlib import Path
from dotenv import load_dotenv
# external packages
from pytest import fixture
# internal packages
from graphdoc import (
DocGeneratorPrompt,
DocQualityPrompt,
LocalDataHelper,
Parser,
setup_logging,
)
# logging
setup_logging("INFO")
log = logging.getLogger(__name__)

# define test asset paths
TEST_DIR = Path(__file__).resolve().parent
ASSETS_DIR = TEST_DIR / "assets"
MLRUNS_DIR = ASSETS_DIR / "mlruns"
ENV_PATH = TEST_DIR / ".env"

# point mlflow at the local test tracking store with fixed credentials
os.environ.update(
    {
        "MLFLOW_TRACKING_URI": str(MLRUNS_DIR),
        "MLFLOW_TRACKING_USERNAME": "admin",
        "MLFLOW_TRACKING_PASSWORD": "password",
    }
)

# Check if .env file exists; when it does, its values win (override=True)
if ENV_PATH.exists():
    log.info(f".env file found at {ENV_PATH}")
    load_dotenv(dotenv_path=ENV_PATH, override=True)
else:
    log.error(f".env file not found at {ENV_PATH}")
# Set default environment variables if not present
def ensure_env_vars():
    """Ensure all required environment variables are set with defaults if needed.

    Logs the status of each tracked variable (masking secret values),
    fills in a default where one is defined, and warns when a required
    variable (default of ``None``) is missing.
    """
    env_defaults = {
        "OPENAI_API_KEY": None,  # No default, must be provided
        "HF_DATASET_KEY": None,  # No default, must be provided
        "MLFLOW_TRACKING_URI": str(MLRUNS_DIR),
    }
    log.info(f"Environment variable path: {ENV_PATH}")
    for key, default in env_defaults.items():
        # use None (not a sentinel string) so a var literally set to
        # "NOT SET" is not misclassified as unset
        value = os.environ.get(key)
        if value is not None:
            # never log secret values
            if "API_KEY" in key or "DATASET_KEY" in key:
                log.info(f"Environment variable {key}: SET (value masked)")
            else:
                log.info(f"Environment variable {key}: SET to {value}")
        else:
            log.info(f"Environment variable {key}: NOT SET")
            if default is not None:
                os.environ[key] = default
                log.info(f"Setting default for {key}: {default}")
            else:
                log.warning(f"Required environment variable {key} not set")
@fixture(autouse=True, scope="session")
def setup_env():
    """Fixture to ensure environment is properly set up for the test session.

    Runs once per session (scope="session"): re-loads the .env file when
    present, then applies defaults via ensure_env_vars().
    """
    if ENV_PATH.exists():
        load_dotenv(dotenv_path=ENV_PATH, override=True)
    ensure_env_vars()
class OverwriteSchemaCategory(Enum):
    """Test-only schema documentation quality categories.

    Values carry a "(TEST)" suffix to distinguish them from the library's
    default categories when passed to LocalDataHelper.
    """

    PERFECT = "perfect (TEST)"
    ALMOST_PERFECT = "almost perfect (TEST)"
    POOR_BUT_CORRECT = "poor but correct (TEST)"
    INCORRECT = "incorrect (TEST)"
    BLANK = "blank (TEST)"
class OverwriteSchemaRating(Enum):
    """Test-only quality ratings, stored as string values.

    NOTE(review): member names and values differ (FOUR -> "8"); presumably
    the values are on a doubled scale — confirm against the library's
    default rating enum before relying on the numbers.
    """

    FOUR = "8"
    THREE = "6"
    TWO = "4"
    ONE = "2"
    ZERO = "0"
class OverwriteSchemaCategoryRatingMapping:
    """Maps OverwriteSchemaCategory members to OverwriteSchemaRating members."""

    @staticmethod
    def get_rating(category: OverwriteSchemaCategory) -> OverwriteSchemaRating:
        """Return the rating for *category*, defaulting to ZERO when unknown.

        Declared as a staticmethod: the overwrite_ldh fixture passes the
        unbound function (OverwriteSchemaCategoryRatingMapping.get_rating),
        which with an instance method would bind the category argument to
        ``self`` and fail. Instance calls keep working unchanged.
        """
        mapping = {
            OverwriteSchemaCategory.PERFECT: OverwriteSchemaRating.FOUR,
            OverwriteSchemaCategory.ALMOST_PERFECT: OverwriteSchemaRating.THREE,
            OverwriteSchemaCategory.POOR_BUT_CORRECT: OverwriteSchemaRating.TWO,
            OverwriteSchemaCategory.INCORRECT: OverwriteSchemaRating.ONE,
            OverwriteSchemaCategory.BLANK: OverwriteSchemaRating.ZERO,
        }
        return mapping.get(category, OverwriteSchemaRating.ZERO)
@fixture
def par() -> Parser:
    """Provide a fresh default Parser instance."""
    return Parser()
@fixture
def default_ldh() -> LocalDataHelper:
    """Provide a LocalDataHelper with library-default categories and ratings."""
    return LocalDataHelper()
@fixture
def overwrite_ldh() -> LocalDataHelper:
    """Provide a LocalDataHelper using the test-only (TEST) enums.

    NOTE(review): categories_ratings receives the unbound
    OverwriteSchemaCategoryRatingMapping.get_rating — confirm the callable
    is invoked with a single category argument by LocalDataHelper.
    """
    return LocalDataHelper(
        categories=OverwriteSchemaCategory,
        ratings=OverwriteSchemaRating,
        categories_ratings=OverwriteSchemaCategoryRatingMapping.get_rating,
    )
@fixture
def dqp():
    """Provide a DocQualityPrompt configured as a simple predict/rating prompt."""
    return DocQualityPrompt(
        prompt="doc_quality",
        prompt_type="predict",
        prompt_metric="rating",
    )
@fixture
def dgp():
    """Provide a DocGeneratorPrompt whose metric is a DocQualityPrompt.

    The nested prompt mirrors the standalone dqp fixture configuration
    (predict/rating) so both fixtures evaluate quality the same way.
    """
    return DocGeneratorPrompt(
        prompt="base_doc_gen",
        prompt_type="chain_of_thought",
        prompt_metric=DocQualityPrompt(
            prompt="doc_quality",
            prompt_type="predict",
            prompt_metric="rating",
        ),
    )
@fixture
def mlflow_dict():
    """Provide mlflow connection settings matching the module-level env vars.

    NOTE(review): the uri value is a Path while os.environ holds
    str(MLRUNS_DIR) — confirm consumers accept a Path here.
    """
    return {
        "mlflow_tracking_uri": MLRUNS_DIR,
        "mlflow_tracking_username": "admin",
        "mlflow_tracking_password": "password",
    }