Skip to content

Commit 97f32a6

Browse files
committed
Add openai setting
1 parent a06909d commit 97f32a6

File tree

5 files changed

+155
-57
lines changed

5 files changed

+155
-57
lines changed

ai_search_with_adi/ai_search/.env

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
FunctionApp__Endpoint=<functionAppEndpoint>
2+
FunctionApp__Key=<functionAppKey>
3+
FunctionApp__PreEmbeddingCleaner__FunctionName=pre_embedding_cleaner
4+
FunctionApp__ADI__FunctionName=adi_2_ai_search
5+
FunctionApp__KeyPhraseExtractor__FunctionName=keyphrase_extractor
6+
FunctionApp__AppRegistrationResourceId=<App registration resource ID in the form api://appRegistrationClientId, if using identity-based connections>
7+
AIService__AzureSearchOptions__IdentityType=<identityType> # system_assigned or user_assigned or key
8+
AIService__AzureSearchOptions__Endpoint=<searchServiceEndpoint>
9+
AIService__AzureSearchOptions__Identity__ClientId=<clientId if using user assigned identity>
10+
AIService__AzureSearchOptions__Key=<searchServiceKey if not using identity>
11+
AIService__AzureSearchOptions__UsePrivateEndpoint=<true/false>
12+
AIService__AzureSearchOptions__Identity__FQName=<fully qualified name of the identity if using user assigned identity>
13+
StorageAccount__FQEndpoint=<Fully qualified endpoint in form ResourceId=resourceId if using identity based connections>
14+
StorageAccount__ConnectionString=<connectionString if using non managed identity>
15+
StorageAccount__RagDocuments__Container=<containerName>
16+
OpenAI__ApiKey=<openAIKey if using non managed identity>
17+
OpenAI__Endpoint=<openAIEndpoint>
18+
OpenAI__EmbeddingModel=<openAIEmbeddingModelName>
19+
OpenAI__EmbeddingDeployment=<openAIEmbeddingDeploymentId>
20+
OpenAI__EmbeddingDimensions=1536

ai_search_with_adi/ai_search/ai_search.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -307,13 +307,13 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:
307307
)
308308

309309
if self.environment.identity_type != IdentityType.KEY:
310-
adi_skill.auth_identity = self.environment.function_app_app_registration_resource_id
311-
312-
if self.environment.identity_type == IdentityType.USER_ASSIGNED:
313310
adi_skill.auth_identity = (
314-
self.environment.ai_search_user_assigned_identity
311+
self.environment.function_app_app_registration_resource_id
315312
)
316313

314+
if self.environment.identity_type == IdentityType.USER_ASSIGNED:
315+
adi_skill.auth_identity = self.environment.ai_search_user_assigned_identity
316+
317317
return adi_skill
318318

319319
def get_vector_skill(
@@ -335,12 +335,20 @@ def get_vector_skill(
335335
name="Vector Skill",
336336
description="Skill to generate embeddings",
337337
context=context,
338-
deployment_id="0",
339-
model_name="text-embedding-3-large",
338+
deployment_id=self.environment.open_ai_embedding_deployment,
339+
model_name=self.environment.open_ai_embedding_model,
340340
inputs=embedding_skill_inputs,
341341
outputs=embedding_skill_outputs,
342+
dimensions=self.environment.open_ai_embedding_dimensions,
342343
)
343344

345+
if self.environment.identity_type == IdentityType.KEY:
346+
vector_skill.api_key = self.environment.open_ai_api_key
347+
elif self.environment.identity_type == IdentityType.USER_ASSIGNED:
348+
vector_skill.auth_identity = (
349+
self.environment.ai_search_user_assigned_identity
350+
)
351+
344352
return vector_skill
345353

346354
def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill:
@@ -403,6 +411,19 @@ def get_vector_search(self) -> VectorSearch:
403411
VectorSearch: The vector search configuration
404412
"""
405413

414+
open_ai_params = AzureOpenAIParameters(
415+
resource_uri=self.environment.open_ai_endpoint,
416+
modelName=self.environment.open_ai_embedding_model,
417+
deploymentId=self.environment.open_ai_embedding_deployment,
418+
)
419+
420+
if self.environment.identity_type == IdentityType.KEY:
421+
open_ai_params.api_key = self.environment.open_ai_api_key
422+
elif self.environment.identity_type == IdentityType.USER_ASSIGNED:
423+
open_ai_params.auth_identity = (
424+
self.environment.ai_search_user_assigned_identity
425+
)
426+
406427
vector_search = VectorSearch(
407428
algorithms=[
408429
HnswAlgorithmConfiguration(name=self.algorithm_name),
@@ -417,7 +438,7 @@ def get_vector_search(self) -> VectorSearch:
417438
vectorizers=[
418439
AzureOpenAIVectorizer(
419440
name=self.vectorizer_name,
420-
azure_open_ai_parameters=AzureOpenAIParameters(),
441+
azure_open_ai_parameters=open_ai_params,
421442
),
422443
],
423444
)

ai_search_with_adi/ai_search/environment.py

+83-28
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,24 @@
77
from azure.core.credentials import AzureKeyCredential
88
from azure.search.documents.indexes.models import SearchIndexerDataUserAssignedIdentity
99

10+
1011
class IndexerType(Enum):
    """Enumeration of the indexer kinds known to this module.

    Each member's value is the hyphenated string name of the indexer.
    """

    # Currently the only indexer kind defined.
    RAG_DOCUMENTS = "rag-documents"
1415

16+
1517
class IdentityType(Enum):
    """Enumeration of the identity / authentication modes.

    The member values are the lowercase strings ``user_assigned``,
    ``system_assigned`` and ``key``.
    """

    USER_ASSIGNED = "user_assigned"
    SYSTEM_ASSIGNED = "system_assigned"
    KEY = "key"
2123

24+
2225
class AISearchEnvironment:
2326
"""This class is used to get the environment variables for the AI search service."""
27+
2428
def __init__(self, indexer_type: IndexerType):
2529
"""Initialize the AISearchEnvironment class.
2630
@@ -33,7 +37,7 @@ def __init__(self, indexer_type: IndexerType):
3337
@property
3438
def normalised_indexer_type(self) -> str:
3539
"""This function returns the normalised indexer type.
36-
40+
3741
Returns:
3842
str: The normalised indexer type
3943
"""
@@ -46,7 +50,7 @@ def normalised_indexer_type(self) -> str:
4650
@property
4751
def identity_type(self) -> IdentityType:
4852
"""This function returns the identity type.
49-
53+
5054
Returns:
5155
IdentityType: The identity type
5256
"""
@@ -60,54 +64,100 @@ def identity_type(self) -> IdentityType:
6064
return IdentityType.KEY
6165
else:
6266
raise ValueError("Invalid identity type")
63-
67+
6468
@property
def ai_search_endpoint(self) -> str:
    """Look up the AI Search endpoint in the process environment.

    Returns:
        str: The value of ``AIService__AzureSearchOptions__Endpoint``, or
            ``None`` when that variable is not set.
    """
    endpoint = os.getenv("AIService__AzureSearchOptions__Endpoint")
    return endpoint
72-
76+
7377
@property
def ai_search_identity_id(self) -> str:
    """Look up the AI Search identity client id in the process environment.

    Returns:
        str: The value of ``AIService__AzureSearchOptions__Identity__ClientId``,
            or ``None`` when that variable is not set.
    """
    client_id = os.getenv("AIService__AzureSearchOptions__Identity__ClientId")
    return client_id
81-
85+
8286
@property
def ai_search_user_assigned_identity(self) -> SearchIndexerDataUserAssignedIdentity:
    """Build the user assigned identity object for AI Search.

    The identity is taken from the
    ``AIService__AzureSearchOptions__Identity__FQName`` environment variable
    (``None`` is passed through when the variable is not set).

    Returns:
        SearchIndexerDataUserAssignedIdentity: The ai search user assigned identity
    """
    fq_name = os.getenv("AIService__AzureSearchOptions__Identity__FQName")
    return SearchIndexerDataUserAssignedIdentity(user_assigned_identity=fq_name)
9299

93100
@property
def ai_search_credential(self) -> DefaultAzureCredential | AzureKeyCredential:
    """Return the credential matching the configured identity type.

    * ``IdentityType.SYSTEM_ASSIGNED``: a default ``DefaultAzureCredential``.
    * ``IdentityType.USER_ASSIGNED``: a ``DefaultAzureCredential`` whose
      ``managed_identity_client_id`` is ``self.ai_search_identity_id``.
    * anything else: an ``AzureKeyCredential`` built from the
      ``AIService__AzureSearchOptions__Key`` environment variable.

    Returns:
        DefaultAzureCredential | AzureKeyCredential: The ai search credential
    """
    # Read the property once; every branch compares against the same value.
    identity_type = self.identity_type

    # Compare with ``==``: a plain Enum member is not a container, so the
    # previous ``identity_type in IdentityType.X`` raised TypeError.
    if identity_type == IdentityType.SYSTEM_ASSIGNED:
        return DefaultAzureCredential()

    if identity_type == IdentityType.USER_ASSIGNED:
        return DefaultAzureCredential(
            managed_identity_client_id=self.ai_search_identity_id
        )

    return AzureKeyCredential(os.environ.get("AIService__AzureSearchOptions__Key"))
117+
118+
@property
def open_ai_api_key(self) -> str:
    """Look up the OpenAI API key in the process environment.

    Returns:
        str: The value of ``OpenAI__ApiKey``, or ``None`` when that variable
            is not set.
    """
    api_key = os.getenv("OpenAI__ApiKey")
    return api_key
126+
127+
@property
def open_ai_endpoint(self) -> str:
    """Look up the OpenAI endpoint in the process environment.

    Returns:
        str: The value of ``OpenAI__Endpoint``, or ``None`` when that variable
            is not set.
    """
    endpoint = os.getenv("OpenAI__Endpoint")
    return endpoint
135+
136+
@property
def open_ai_embedding_model(self) -> str:
    """Look up the OpenAI embedding model name in the process environment.

    Returns:
        str: The value of ``OpenAI__EmbeddingModel``, or ``None`` when that
            variable is not set.
    """
    model_name = os.getenv("OpenAI__EmbeddingModel")
    return model_name
144+
145+
@property
def open_ai_embedding_deployment(self) -> str:
    """Look up the OpenAI embedding deployment in the process environment.

    Returns:
        str: The value of ``OpenAI__EmbeddingDeployment``, or ``None`` when
            that variable is not set.
    """
    deployment = os.getenv("OpenAI__EmbeddingDeployment")
    return deployment
106153

107154
@property
def storage_account_connection_string(self) -> str:
    """Return the storage account connection setting for the identity type.

    Returns:
        str: The value of ``StorageAccount__FQEndpoint`` when the identity
            type is system_assigned or user_assigned, otherwise the value of
            ``StorageAccount__ConnectionString``. ``None`` when the selected
            variable is not set.
    """
    uses_managed_identity = self.identity_type in (
        IdentityType.SYSTEM_ASSIGNED,
        IdentityType.USER_ASSIGNED,
    )
    env_key = (
        "StorageAccount__FQEndpoint"
        if uses_managed_identity
        else "StorageAccount__ConnectionString"
    )
    return os.getenv(env_key)
@@ -118,8 +168,10 @@ def storage_account_blob_container_name(self) -> str:
118168
This function returns azure blob container name
119169
"""
120170

121-
return os.environ.get(f"StorageAccount__{self.normalised_indexer_type}__Container")
122-
171+
return os.environ.get(
172+
f"StorageAccount__{self.normalised_indexer_type}__Container"
173+
)
174+
123175
@property
124176
def function_app_end_point(self) -> str:
125177
"""
@@ -133,14 +185,14 @@ def function_app_key(self) -> str:
133185
This function returns function app key
134186
"""
135187
return os.environ.get("FunctionApp__Key")
136-
188+
137189
@property
def function_app_app_registration_resource_id(self) -> str:
    """Look up the function app's app registration resource id.

    Returns:
        str: The value of ``FunctionApp__AppRegistrationResourceId``, or
            ``None`` when that variable is not set.
    """
    resource_id = os.getenv("FunctionApp__AppRegistrationResourceId")
    return resource_id
143-
195+
144196
@property
145197
def function_app_pre_embedding_cleaner_route(self) -> str:
146198
"""
@@ -161,26 +213,27 @@ def function_app_key_phrase_extractor_route(self) -> str:
161213
This function returns function app keyphrase extractor name
162214
"""
163215
return os.environ.get("FunctionApp__KeyPhraseExtractor__FunctionName")
164-
216+
165217
@property
def open_ai_embedding_dimensions(self) -> str:
    """Look up the embedding model dimensions in the process environment.

    The value is returned exactly as stored in the environment; no numeric
    conversion is performed here.

    Returns:
        str: The value of ``OpenAI__EmbeddingDimensions``, or ``None`` when
            that variable is not set.
    """
    dimensions = os.getenv("OpenAI__EmbeddingDimensions")
    return dimensions
227+
178228
@property
def use_private_endpoint(self) -> bool:
    """Report whether a private endpoint is configured.

    Returns:
        bool: True only when
            ``AIService__AzureSearchOptions__UsePrivateEndpoint`` is exactly
            the string ``"true"``; False for any other value or when the
            variable is not set.
    """
    configured = os.getenv("AIService__AzureSearchOptions__UsePrivateEndpoint")
    return configured == "true"
184237

185238
def get_custom_skill_function_url(self, skill_type: str):
186239
"""
@@ -194,7 +247,9 @@ def get_custom_skill_function_url(self, skill_type: str):
194247
route = self.function_app_key_phrase_extractor_route
195248
else:
196249
raise ValueError(f"Invalid skill type: {skill_type}")
197-
198-
full_url = f"{self.function_app_end_point}/api/{route}?code={self.function_app_key}"
199250

200-
return full_url
251+
full_url = (
252+
f"{self.function_app_end_point}/api/{route}?code={self.function_app_key}"
253+
)
254+
255+
return full_url

ai_search_with_adi/ai_search/rag_documents.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def get_index_fields(self) -> list[SearchableField]:
8383
SearchField(
8484
name="ChunkEmbedding",
8585
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
86-
vector_search_dimensions=self.environment.ai_search_embedding_model_dimensions,
86+
vector_search_dimensions=self.environment.open_ai_embedding_dimensions,
8787
vector_search_profile_name=self.vector_search_profile_name,
8888
),
8989
SearchableField(

0 commit comments

Comments
 (0)