Skip to content

Feature/adi skillset #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
690 changes: 690 additions & 0 deletions ai_search_with_adi/ai_search.py

Large diffs are not rendered by default.

67 changes: 67 additions & 0 deletions ai_search_with_adi/deploy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import argparse
from environment import get_search_endpoint, get_managed_identity_id, get_search_key,get_key_vault_url
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential,ManagedIdentityCredential,EnvironmentCredential
from azure.keyvault.secrets import SecretClient
from inquiry_document import InquiryDocumentAISearch


def main(args):
    """Deploy the requested AI Search index configuration.

    Args:
        args: Parsed CLI arguments (indexer_type, suffix, rebuild,
            enable_page_chunking).
    """
    endpoint = get_search_endpoint()

    try:
        credential = DefaultAzureCredential(
            managed_identity_client_id=get_managed_identity_id()
        )
        print("Using managed identity credential")
    except Exception as e:
        print(e)
        # Fall back to an admin key pulled from Key Vault. The original code
        # reused a SecretClient created inside the try block, which is unbound
        # here when credential construction fails (NameError); build the
        # client in this branch with an environment-based credential instead.
        client = SecretClient(
            vault_url=get_key_vault_url(), credential=EnvironmentCredential()
        )
        credential = AzureKeyCredential(get_search_key(client=client))
        print("Using Azure Key credential")

    if args.indexer_type == "inquiry":
        # Deploy the inquiry index
        index_config = InquiryDocumentAISearch(
            endpoint=endpoint,
            credential=credential,
            suffix=args.suffix,
            rebuild=args.rebuild,
            enable_page_by_chunking=args.enable_page_chunking,
        )
        index_config.deploy()

        if args.rebuild:
            index_config.reset_indexer()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process some arguments.")
    parser.add_argument(
        "--indexer_type",
        type=str,
        required=True,
        help="Type of indexer to deploy: inquiry/summary/glossary",
    )
    # NOTE: the original used `type=bool`, which is an argparse trap —
    # bool() of any non-empty string (including "False") is True. Boolean
    # switches are expressed as store_true flags instead.
    parser.add_argument(
        "--rebuild",
        action="store_true",
        help="Delete and rebuild the index",
    )
    parser.add_argument(
        "--enable_page_chunking",
        action="store_true",
        help="Enable chunking by page in the ADI skill; omitted means False",
    )
    parser.add_argument(
        "--suffix",
        type=str,
        required=False,
        default=None,
        help="Suffix to be attached to indexer objects",
    )

    args = parser.parse_args()
    main(args)
192 changes: 192 additions & 0 deletions ai_search_with_adi/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
"""Module providing environment definition"""
import os
from dotenv import find_dotenv, load_dotenv
from enum import Enum

load_dotenv(find_dotenv())


class IndexerType(Enum):
    """The type of the indexer.

    Values are kebab-case identifiers; helpers elsewhere in this module
    camel-case them to build environment-variable names
    (e.g. "inquiry-document" -> "InquiryDocument").
    """

    INQUIRY_DOCUMENT = "inquiry-document"
    SUMMARY_DOCUMENT = "summary-document"
    BUSINESS_GLOSSARY = "business-glossary"

# key vault
def get_key_vault_url() -> str:
    """Return the Key Vault URL configured in the environment."""
    vault_url = os.environ.get("KeyVault__Url")
    return vault_url

# managed identity id
def get_managed_identity_id() -> str:
    """Return the managed identity client id configured in the environment."""
    client_id = os.environ.get(
        "AIService__AzureSearchOptions__ManagedIdentity__ClientId"
    )
    return client_id


def get_managed_identity_fqname() -> str:
    """Return the managed identity fully-qualified name from the environment."""
    fq_name = os.environ.get(
        "AIService__AzureSearchOptions__ManagedIdentity__FQName"
    )
    return fq_name


# function app details
def get_function_app_authresourceid() -> str:
    """Return the function app's auth resource id (Entra ID app registration)."""
    resource_id = os.environ.get("FunctionApp__AuthResourceId")
    return resource_id


def get_function_app_end_point() -> str:
    """Return the function app endpoint from the environment."""
    endpoint = os.environ.get("FunctionApp__Endpoint")
    return endpoint

def get_function_app_key() -> str:
    """Return the function app host key from the environment."""
    app_key = os.environ.get("FunctionApp__Key")
    return app_key

def get_function_app_compass_function() -> str:
    """Return the compass function name from the environment."""
    function_name = os.environ.get("FunctionApp__Compass__FunctionName")
    return function_name


def get_function_app_pre_embedding_cleaner_function() -> str:
    """Return the pre-embedding data-cleanup function name from the environment."""
    function_name = os.environ.get("FunctionApp__PreEmbeddingCleaner__FunctionName")
    return function_name


def get_function_app_adi_function() -> str:
    """Return the Document Intelligence (ADI) function name from the environment."""
    function_name = os.environ.get("FunctionApp__DocumentIntelligence__FunctionName")
    return function_name


def get_function_app_custom_split_function() -> str:
    """Return the custom text split function name from the environment.

    (The original docstring said "adi name" — copy-paste error.)
    """
    function_name = os.environ.get("FunctionApp__CustomTextSplit__FunctionName")
    return function_name


def get_function_app_keyphrase_extractor_function() -> str:
    """Return the keyphrase extractor function name from the environment."""
    function_name = os.environ.get("FunctionApp__KeyphraseExtractor__FunctionName")
    return function_name


def get_function_app_ocr_function() -> str:
    """Return the OCR function name from the environment."""
    function_name = os.environ.get("FunctionApp__Ocr__FunctionName")
    return function_name


# search
def get_search_endpoint() -> str:
    """Return the Azure AI Search service endpoint from the environment."""
    endpoint = os.environ.get("AIService__AzureSearchOptions__Endpoint")
    return endpoint


def get_search_user_assigned_identity() -> str:
    """Return the search service's user-assigned identity from the environment.

    (The original docstring said "endpoint" — copy-paste error.)
    """
    identity = os.environ.get("AIService__AzureSearchOptions__UserAssignedIdentity")
    return identity


def get_search_key(client) -> str:
    """Fetch the Azure AI Search admin key from Key Vault.

    Args:
        client: a Key Vault SecretClient used to retrieve the secret named
            "<search service name>-PrimaryKey".
    """
    secret_name = f"{os.environ.get('AIService__AzureSearchOptions__name')}-PrimaryKey"
    secret = client.get_secret(secret_name)
    return secret.value

def get_search_key_secret() -> str:
    """Return the AI Search admin key secret reference from the environment."""
    secret = os.environ.get("AIService__AzureSearchOptions__Key__Secret")
    return secret


def get_search_embedding_model_dimensions(indexer_type: IndexerType) -> str:
"""
This function returns dimensions for embedding model
"""

normalised_indexer_type = (
indexer_type.value.replace("-", " ").title().replace(" ", "")
)

return os.environ.get(
f"AIService__AzureSearchOptions__{normalised_indexer_type}__EmbeddingDimensions"
)

def get_blob_connection_string() -> str:
    """Return the Azure Blob Storage connection string from the environment."""
    connection_string = os.environ.get("StorageAccount__ConnectionString")
    return connection_string

def get_fq_blob_connection_string() -> str:
    """Return the fully-qualified Azure Blob Storage endpoint from the environment."""
    fq_endpoint = os.environ.get("StorageAccount__FQEndpoint")
    return fq_endpoint


def get_blob_container_name(indexer_type: "IndexerType") -> str:
    """Return the Azure Blob container name for the given indexer type.

    Fix: the parameter was annotated `str`, but the body calls
    `indexer_type.value` — it actually takes an IndexerType enum member
    (string annotation keeps resolution lazy).
    """
    normalised_indexer_type = (
        indexer_type.value.replace("-", " ").title().replace(" ", "")
    )
    return os.environ.get(f"StorageAccount__{normalised_indexer_type}__Container")


def get_custom_skill_function_url(skill_type: str) -> str:
    """Get the function app url that is hosting the custom skill.

    Args:
        skill_type: one of "compass", "pre_embedding_cleaner", "adi",
            "split", "keyphraseextraction", "ocr".

    Raises:
        ValueError: if skill_type is unknown. (The original if/elif chain
            silently returned a URL containing the literal placeholder
            "function_name" for unknown types.)
    """
    # Dispatch table instead of an if/elif chain; getters are called lazily
    # so only the selected skill's env var is read.
    skill_to_function = {
        "compass": get_function_app_compass_function,
        "pre_embedding_cleaner": get_function_app_pre_embedding_cleaner_function,
        "adi": get_function_app_adi_function,
        "split": get_function_app_custom_split_function,
        "keyphraseextraction": get_function_app_keyphrase_extractor_function,
        "ocr": get_function_app_ocr_function,
    }
    if skill_type not in skill_to_function:
        raise ValueError(f"Unknown skill_type: {skill_type!r}")

    function_name = skill_to_function[skill_type]()
    # NOTE(review): the host key is embedded in the URL query string; make
    # sure these URLs are never logged verbatim.
    return (
        get_function_app_end_point()
        + f"/api/{function_name}?code="
        + get_function_app_key()
    )
127 changes: 127 additions & 0 deletions ai_search_with_adi/function_apps/common/ai_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from azure.search.documents.indexes.aio import SearchIndexerClient, SearchIndexClient
from azure.search.documents.aio import SearchClient
from azure.search.documents.indexes.models import SynonymMap
from azure.identity import DefaultAzureCredential
from azure.core.exceptions import HttpResponseError
import logging
import os
from enum import Enum
from openai import AsyncAzureOpenAI
from azure.search.documents.models import VectorizedQuery


class IndexerStatusEnum(Enum):
    """Tri-state outcome reported when polling an indexer's last run.

    RETRIGGER: last run neither in progress nor successful — re-run it.
    RUNNING:   last run still in progress.
    SUCCESS:   last run succeeded (transient failures count as success).
    """

    RETRIGGER = "RETRIGGER"
    RUNNING = "RUNNING"
    SUCCESS = "SUCCESS"


class AISearchHelper:
    """Async helper around the Azure AI Search index/indexer/search clients.

    All Azure calls authenticate with the managed identity whose client id
    is supplied via the FunctionApp__ClientId environment variable.
    """

    def __init__(self):
        # Managed identity client id and search endpoint are required config.
        self._client_id = os.environ["FunctionApp__ClientId"]
        self._endpoint = os.environ["AIService__AzureSearchOptions__Endpoint"]

    def _credential(self):
        """Build a credential bound to the configured managed identity."""
        return DefaultAzureCredential(managed_identity_client_id=self._client_id)

    async def get_index_client(self):
        """Return an async SearchIndexClient for this search service."""
        return SearchIndexClient(self._endpoint, self._credential())

    async def get_indexer_client(self):
        """Return an async SearchIndexerClient for this search service."""
        return SearchIndexerClient(self._endpoint, self._credential())

    async def get_search_client(self, index_name):
        """Return an async SearchClient bound to the given index."""
        return SearchClient(self._endpoint, index_name, self._credential())

    async def upload_synonym_map(self, synonym_map_name: str, synonyms: str):
        """Replace the named synonym map with the provided synonyms.

        The existing map is deleted first; a delete failure (e.g. the map
        does not exist yet) is logged and creation proceeds regardless.
        """
        index_client = await self.get_index_client()
        async with index_client:
            try:
                await index_client.delete_synonym_map(synonym_map_name)
            except HttpResponseError as e:
                logging.error("Unable to delete synonym map %s", e)

            logging.info("Synonyms: %s", synonyms)
            synonym_map = SynonymMap(name=synonym_map_name, synonyms=synonyms)
            await index_client.create_synonym_map(synonym_map)

    async def get_indexer_status(self, indexer_name):
        """Map the indexer's last execution result onto IndexerStatusEnum.

        Returns:
            A (IndexerStatusEnum, start_time) tuple for the last run. On an
            HTTP error the failure is logged and (RETRIGGER, None) is
            returned.
        """
        indexer_client = await self.get_indexer_client()
        async with indexer_client:
            try:
                status = await indexer_client.get_indexer_status(indexer_name)
                last_execution_result = status.last_result

                if last_execution_result.status == "inProgress":
                    return IndexerStatusEnum.RUNNING, last_execution_result.start_time
                if last_execution_result.status in ["success", "transientFailure"]:
                    return IndexerStatusEnum.SUCCESS, last_execution_result.start_time
                return IndexerStatusEnum.RETRIGGER, last_execution_result.start_time
            except HttpResponseError as e:
                logging.error("Unable to get indexer status %s", e)
                # Bug fix: the original fell through and implicitly returned
                # None here, which breaks callers that unpack a 2-tuple.
                return IndexerStatusEnum.RETRIGGER, None

    async def trigger_indexer(self, indexer_name):
        """Start a run of the named indexer; failures are logged, not raised."""
        indexer_client = await self.get_indexer_client()
        async with indexer_client:
            try:
                await indexer_client.run_indexer(indexer_name)
            except HttpResponseError as e:
                logging.error("Unable to run indexer %s", e)

    async def search_index(
        self, index_name, semantic_config, search_text, deal_id=None
    ):
        """Run a combined semantic + vector search for the given text.

        Embeds `search_text` with the configured Azure OpenAI embedding
        deployment, then issues a semantic query with a 5-NN vector query
        over the ChunkEmbedding field. When `deal_id` is given, results are
        filtered to documents whose DealId matches.

        Returns:
            List of result documents (Title, Chunk fields) from the top 3
            hits.
        """
        async with AsyncAzureOpenAI(
            api_key=os.environ["AIService__Compass_Key"],
            azure_endpoint=os.environ["AIService__Compass_Endpoint"],
            api_version="2023-03-15-preview",
        ) as open_ai_client:
            embeddings = await open_ai_client.embeddings.create(
                model=os.environ["AIService__Compass_Models__Embedding"],
                input=search_text,
            )

        # A single input string was embedded, so take the first vector.
        embedding_vector = embeddings.data[0].embedding

        vector_query = VectorizedQuery(
            vector=embedding_vector,
            k_nearest_neighbors=5,
            fields="ChunkEmbedding",
        )

        filter_expression = f"DealId eq '{deal_id}'" if deal_id else None
        logging.info("Filter Expression: %s", filter_expression)

        search_client = await self.get_search_client(index_name)
        async with search_client:
            results = await search_client.search(
                top=3,
                query_type="semantic",
                semantic_configuration_name=semantic_config,
                search_text=search_text,
                select="Title,Chunk",
                vector_queries=[vector_query],
                filter=filter_expression,
            )

            documents = [
                document
                async for result in results.by_page()
                async for document in result
            ]

        logging.info("Documents: %s", documents)
        return documents
Loading
Loading