diff --git a/packages/kbn-doc-links/src/get_doc_links.ts b/packages/kbn-doc-links/src/get_doc_links.ts index 8234daa4b4454..42070ca3053f1 100644 --- a/packages/kbn-doc-links/src/get_doc_links.ts +++ b/packages/kbn-doc-links/src/get_doc_links.ts @@ -520,6 +520,7 @@ export const getDocLinks = ({ kibanaBranch }: GetDocLinkOptions): DocLinks => { trainedModels: `${MACHINE_LEARNING_DOCS}ml-trained-models.html`, startTrainedModelsDeployment: `${MACHINE_LEARNING_DOCS}ml-nlp-deploy-model.html`, nlpElser: `${MACHINE_LEARNING_DOCS}ml-nlp-elser.html`, + nlpE5: `${MACHINE_LEARNING_DOCS}ml-nlp-e5.html`, nlpImportModel: `${MACHINE_LEARNING_DOCS}ml-nlp-import-model.html`, }, transforms: { diff --git a/x-pack/packages/ml/trained_models_utils/index.ts b/x-pack/packages/ml/trained_models_utils/index.ts index 0ae43c5ef4013..b9ad2e2ae4d4e 100644 --- a/x-pack/packages/ml/trained_models_utils/index.ts +++ b/x-pack/packages/ml/trained_models_utils/index.ts @@ -19,7 +19,8 @@ export { type ModelDefinition, type ModelDefinitionResponse, type ElserVersion, - type GetElserOptions, + type GetModelDownloadConfigOptions, + type ElasticCuratedModelName, ELSER_ID_V1, ELASTIC_MODEL_TAG, ELASTIC_MODEL_TYPE, diff --git a/x-pack/packages/ml/trained_models_utils/src/constants/trained_models.ts b/x-pack/packages/ml/trained_models_utils/src/constants/trained_models.ts index 917e7b5cac3e4..10e6618672f49 100644 --- a/x-pack/packages/ml/trained_models_utils/src/constants/trained_models.ts +++ b/x-pack/packages/ml/trained_models_utils/src/constants/trained_models.ts @@ -61,6 +61,7 @@ export const ELASTIC_MODEL_DEFINITIONS: Record = Object description: i18n.translate('xpack.ml.trainedModels.modelsList.elserDescription', { defaultMessage: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)', }), + type: ['elastic', 'pytorch', 'text_expansion'], }, '.elser_model_2': { modelName: 'elser', @@ -74,6 +75,7 @@ export const ELASTIC_MODEL_DEFINITIONS: Record = Object description: i18n.translate('xpack.ml.trainedModels.modelsList.elserV2Description', { defaultMessage: 'Elastic Learned Sparse EncodeR v2', }), + type: ['elastic', 'pytorch', 'text_expansion'], }, '.elser_model_2_linux-x86_64': { modelName: 'elser', @@ -88,14 +90,49 @@ export const ELASTIC_MODEL_DEFINITIONS: Record = Object description: i18n.translate('xpack.ml.trainedModels.modelsList.elserV2x86Description', { defaultMessage: 'Elastic Learned Sparse EncodeR v2, optimized for linux-x86_64', }), + type: ['elastic', 'pytorch', 'text_expansion'], + }, + '.multilingual-e5-small': { + modelName: 'e5', + version: 1, + default: true, + config: { + input: { + field_names: ['text_field'], + }, + }, + description: i18n.translate('xpack.ml.trainedModels.modelsList.e5v1Description', { + defaultMessage: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)', + }), + license: 'MIT', + type: ['pytorch', 'text_embedding'], + }, + '.multilingual-e5-small_linux-x86_64': { + modelName: 'e5', + version: 1, + os: 'Linux', + arch: 'amd64', + config: { + input: { + field_names: ['text_field'], + }, + }, + description: i18n.translate('xpack.ml.trainedModels.modelsList.e5v1x86Description', { + defaultMessage: + 'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64', + }), + license: 'MIT', + type: ['pytorch', 'text_embedding'], }, } as const); +export type ElasticCuratedModelName = 'elser' | 'e5'; + export interface ModelDefinition { /** * Model name, e.g. elser */ - modelName: string; + modelName: ElasticCuratedModelName; version: number; /** * Default PUT model configuration @@ -107,13 +144,15 @@ export interface ModelDefinition { default?: boolean; recommended?: boolean; hidden?: boolean; + license?: string; + type?: readonly string[]; } export type ModelDefinitionResponse = ModelDefinition & { /** * Complete model id, e.g. .elser_model_2_linux-x86_64 */ - name: string; + model_id: string; }; export type ElasticModelId = keyof typeof ELASTIC_MODEL_DEFINITIONS; @@ -129,6 +168,6 @@ export type ModelState = typeof MODEL_STATE[keyof typeof MODEL_STATE] | null; export type ElserVersion = 1 | 2; -export interface GetElserOptions { +export interface GetModelDownloadConfigOptions { version?: ElserVersion; } diff --git a/x-pack/plugins/elastic_assistant/server/plugin.ts b/x-pack/plugins/elastic_assistant/server/plugin.ts index 827b428c97803..a0df339695885 100755 --- a/x-pack/plugins/elastic_assistant/server/plugin.ts +++ b/x-pack/plugins/elastic_assistant/server/plugin.ts @@ -80,7 +80,7 @@ export class ElasticAssistantPlugin const getElserId: GetElser = once( async (request: KibanaRequest, savedObjectsClient: SavedObjectsClientContract) => { return (await plugins.ml.trainedModelsProvider(request, savedObjectsClient).getELSER()) - .name; + .model_id; } ); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/text_expansion_callout/text_expansion_callout_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/text_expansion_callout/text_expansion_callout_logic.ts index 06d4f553bbabd..35544bc5d5685 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/text_expansion_callout/text_expansion_callout_logic.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/text_expansion_callout/text_expansion_callout_logic.ts @@ -179,7 +179,7 @@ export const TextExpansionCalloutLogic = kea< afterMount: async () => { const elserModel = await KibanaLogic.values.ml.elasticModels?.getELSER({ version: 2 }); if (elserModel != null) { - actions.setElserModelId(elserModel.name); + actions.setElserModelId(elserModel.model_id); actions.fetchTextExpansionModel(); } }, diff --git a/x-pack/plugins/ml/public/application/model_management/add_model_flyout.tsx b/x-pack/plugins/ml/public/application/model_management/add_model_flyout.tsx index 267b00b95cca2..6c69446b8725e 100644 --- a/x-pack/plugins/ml/public/application/model_management/add_model_flyout.tsx +++ b/x-pack/plugins/ml/public/application/model_management/add_model_flyout.tsx @@ -42,52 +42,52 @@ export interface AddModelFlyoutProps { onSubmit: (modelId: string) => void; } +type FlyoutTabId = 'clickToDownload' | 'manualDownload'; + /** * Flyout for downloading elastic curated models and showing instructions for importing third-party models. */ export const AddModelFlyout: FC = ({ onClose, onSubmit, modelDownloads }) => { const canCreateTrainedModels = usePermissionCheck('canCreateTrainedModels'); - const isElserTabVisible = canCreateTrainedModels && modelDownloads.length > 0; + const isClickToDownloadTabVisible = canCreateTrainedModels && modelDownloads.length > 0; - const [selectedTabId, setSelectedTabId] = useState(isElserTabVisible ? 'elser' : 'thirdParty'); + const [selectedTabId, setSelectedTabId] = useState( + isClickToDownloadTabVisible ? 'clickToDownload' : 'manualDownload' + ); const tabs = useMemo(() => { return [ - ...(isElserTabVisible + ...(isClickToDownloadTabVisible ? [ { - id: 'elser', + id: 'clickToDownload' as const, name: ( - - - - - - - - + ), content: ( - + ), }, ] : []), { - id: 'thirdParty', + id: 'manualDownload' as const, name: ( ), - content: , + content: , }, ]; - }, [isElserTabVisible, modelDownloads, onSubmit]); + }, [isClickToDownloadTabVisible, modelDownloads, onSubmit]); const selectedTabContent = useMemo(() => { return tabs.find((obj) => obj.id === selectedTabId)?.content; @@ -133,15 +133,18 @@ export const AddModelFlyout: FC = ({ onClose, onSubmit, mod ); }; -interface ElserTabContentProps { +interface ClickToDownloadTabContentProps { modelDownloads: ModelItem[]; onModelDownload: (modelId: string) => void; } /** - * ELSER tab content for selecting a model to download. + * Tab content for selecting a model to download. */ -const ElserTabContent: FC = ({ modelDownloads, onModelDownload }) => { +const ClickToDownloadTabContent: FC = ({ + modelDownloads, + onModelDownload, +}) => { const { services: { docLinks }, } = useMlKibana(); @@ -157,26 +160,33 @@ const ElserTabContent: FC = ({ modelDownloads, onModelDown {modelName === 'elser' ? (
- -

- -

-
+ + + + + + +

+ +

+
+
+

- + = ({ modelDownloads, onModelDown

) : null} + {modelName === 'e5' ? ( +
+ +

+ +

+
+ +

+ + + +

+ + + + + + + + + + + + + + +
+ ) : null} + = ({ modelDownloads, onModelDown ), }} > - {models.map((model) => { + {models.map((model, index) => { return ( = ({ modelDownloads, onModelDown checked={model.model_id === selectedModelId} onChange={setSelectedModelId.bind(null, model.model_id)} /> - + {index < models.length - 1 ? : null} ); })} +
); })} @@ -279,9 +336,9 @@ const ElserTabContent: FC = ({ modelDownloads, onModelDown }; /** - * Third-party tab content for showing instructions for importing third-party models. + * Manual download tab content for showing instructions for importing third-party models. */ -const ThirdPartyTabContent: FC = () => { +const ManualDownloadTabContent: FC = () => { const { services: { docLinks }, } = useMlKibana(); diff --git a/x-pack/plugins/ml/public/application/model_management/models_list.tsx b/x-pack/plugins/ml/public/application/model_management/models_list.tsx index 6d9dda39e2853..0fc27fcb33fd4 100644 --- a/x-pack/plugins/ml/public/application/model_management/models_list.tsx +++ b/x-pack/plugins/ml/public/application/model_management/models_list.tsx @@ -262,17 +262,17 @@ export const ModelsList: FC = ({ ); const forDownload = await trainedModelsApiService.getTrainedModelDownloads(); const notDownloaded: ModelItem[] = forDownload - .filter(({ name, hidden, recommended }) => { - if (recommended && idMap.has(name)) { - idMap.get(name)!.recommended = true; + .filter(({ model_id: modelId, hidden, recommended }) => { + if (recommended && idMap.has(modelId)) { + idMap.get(modelId)!.recommended = true; } - return !idMap.has(name) && !hidden; + return !idMap.has(modelId) && !hidden; }) .map((modelDefinition) => { return { - model_id: modelDefinition.name, - type: [ELASTIC_MODEL_TYPE], - tags: [ELASTIC_MODEL_TAG], + model_id: modelDefinition.model_id, + type: modelDefinition.type, + tags: modelDefinition.type?.includes(ELASTIC_MODEL_TAG) ? [ELASTIC_MODEL_TAG] : [], putModelConfig: modelDefinition.config, description: modelDefinition.description, state: MODEL_STATE.NOT_DOWNLOADED, diff --git a/x-pack/plugins/ml/public/application/services/elastic_models_service.ts b/x-pack/plugins/ml/public/application/services/elastic_models_service.ts index 2591fb6d82e7d..efc6249f9582b 100644 --- a/x-pack/plugins/ml/public/application/services/elastic_models_service.ts +++ b/x-pack/plugins/ml/public/application/services/elastic_models_service.ts @@ -5,7 +5,10 @@ * 2.0. */ -import type { ModelDefinitionResponse, GetElserOptions } from '@kbn/ml-trained-models-utils'; +import type { + ModelDefinitionResponse, + GetModelDownloadConfigOptions, +} from '@kbn/ml-trained-models-utils'; import { type TrainedModelsApiService } from './ml_api_service/trained_models'; export class ElasticModels { @@ -17,7 +20,7 @@ export class ElasticModels { * If any of the ML nodes run a different OS rather than Linux, or the CPU architecture isn't x86_64, * a portable version of the model is returned. */ - public async getELSER(options?: GetElserOptions): Promise { + public async getELSER(options?: GetModelDownloadConfigOptions): Promise { return await this.trainedModels.getElserConfig(options); } } diff --git a/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts b/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts index b886f6f7df8e5..9bf880fb2b312 100644 --- a/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts +++ b/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts @@ -11,7 +11,10 @@ import type { IngestPipeline } from '@elastic/elasticsearch/lib/api/types'; import { useMemo } from 'react'; import type { HttpFetchQuery } from '@kbn/core/public'; import type { ErrorType } from '@kbn/ml-error-utils'; -import type { GetElserOptions, ModelDefinitionResponse } from '@kbn/ml-trained-models-utils'; +import type { + GetModelDownloadConfigOptions, + ModelDefinitionResponse, +} from '@kbn/ml-trained-models-utils'; import { ML_INTERNAL_BASE_PATH } from '../../../../common/constants/app'; import type { MlSavedObjectType } from '../../../../common/types/saved_objects'; import { HttpService } from '../http_service'; @@ -73,7 +76,7 @@ export function trainedModelsApiProvider(httpService: HttpService) { /** * Gets ELSER config for download based on the cluster OS and CPU architecture. */ - getElserConfig(options?: GetElserOptions) { + getElserConfig(options?: GetModelDownloadConfigOptions) { return httpService.http({ path: `${ML_INTERNAL_BASE_PATH}/trained_models/elser_config`, method: 'GET', diff --git a/x-pack/plugins/ml/public/mocks.ts b/x-pack/plugins/ml/public/mocks.ts index be18bfb1f49f1..8a2c7efefc9c5 100644 --- a/x-pack/plugins/ml/public/mocks.ts +++ b/x-pack/plugins/ml/public/mocks.ts @@ -20,7 +20,7 @@ const createElasticModelsMock = (): jest.Mocked => { }, }, description: 'Elastic Learned Sparse EncodeR v2 (Tech Preview)', - name: '.elser_model_2', + model_id: '.elser_model_2', }), } as unknown as jest.Mocked; }; diff --git a/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts b/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts index 5267a95d4fb48..ff18327fdea5e 100644 --- a/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts +++ b/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts @@ -54,27 +54,53 @@ describe('modelsProvider', () => { config: { input: { field_names: ['text_field'] } }, description: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)', hidden: true, - name: '.elser_model_1', + model_id: '.elser_model_1', version: 1, modelName: 'elser', + type: ['elastic', 'pytorch', 'text_expansion'], }, { config: { input: { field_names: ['text_field'] } }, default: true, description: 'Elastic Learned Sparse EncodeR v2', - name: '.elser_model_2', + model_id: '.elser_model_2', version: 2, modelName: 'elser', + type: ['elastic', 'pytorch', 'text_expansion'], }, { arch: 'amd64', config: { input: { field_names: ['text_field'] } }, description: 'Elastic Learned Sparse EncodeR v2, optimized for linux-x86_64', - name: '.elser_model_2_linux-x86_64', + model_id: '.elser_model_2_linux-x86_64', os: 'Linux', recommended: true, version: 2, modelName: 'elser', + type: ['elastic', 'pytorch', 'text_expansion'], + }, + { + config: { input: { field_names: ['text_field'] } }, + description: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)', + model_id: '.multilingual-e5-small', + default: true, + version: 1, + modelName: 'e5', + license: 'MIT', + type: ['pytorch', 'text_embedding'], + }, + { + arch: 'amd64', + config: { input: { field_names: ['text_field'] } }, + description: + 'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64', + model_id: '.multilingual-e5-small_linux-x86_64', + os: 'Linux', + recommended: true, + version: 1, + modelName: 'e5', + license: 'MIT', + type: ['pytorch', 'text_embedding'], }, ]); }); @@ -108,26 +134,51 @@ describe('modelsProvider', () => { config: { input: { field_names: ['text_field'] } }, description: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)', hidden: true, - name: '.elser_model_1', + model_id: '.elser_model_1', version: 1, modelName: 'elser', + type: ['elastic', 'pytorch', 'text_expansion'], }, { config: { input: { field_names: ['text_field'] } }, recommended: true, description: 'Elastic Learned Sparse EncodeR v2', - name: '.elser_model_2', + model_id: '.elser_model_2', version: 2, modelName: 'elser', + type: ['elastic', 'pytorch', 'text_expansion'], }, { arch: 'amd64', config: { input: { field_names: ['text_field'] } }, description: 'Elastic Learned Sparse EncodeR v2, optimized for linux-x86_64', - name: '.elser_model_2_linux-x86_64', + model_id: '.elser_model_2_linux-x86_64', os: 'Linux', version: 2, modelName: 'elser', + type: ['elastic', 'pytorch', 'text_expansion'], + }, + { + config: { input: { field_names: ['text_field'] } }, + description: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)', + model_id: '.multilingual-e5-small', + recommended: true, + version: 1, + modelName: 'e5', + type: ['pytorch', 'text_embedding'], + license: 'MIT', + }, + { + arch: 'amd64', + config: { input: { field_names: ['text_field'] } }, + description: + 'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64', + model_id: '.multilingual-e5-small_linux-x86_64', + os: 'Linux', + version: 1, + modelName: 'e5', + type: ['pytorch', 'text_embedding'], + license: 'MIT', }, ]); }); @@ -136,7 +187,7 @@ describe('modelsProvider', () => { describe('getELSER', () => { test('provides a recommended definition by default', async () => { const result = await modelService.getELSER(); - expect(result.name).toEqual('.elser_model_2_linux-x86_64'); + expect(result.model_id).toEqual('.elser_model_2_linux-x86_64'); }); test('provides a default version if there is no recommended', async () => { @@ -162,17 +213,50 @@ describe('modelsProvider', () => { }); const result = await modelService.getELSER(); - expect(result.name).toEqual('.elser_model_2'); + expect(result.model_id).toEqual('.elser_model_2'); }); test('provides the requested version', async () => { const result = await modelService.getELSER({ version: 1 }); - expect(result.name).toEqual('.elser_model_1'); + expect(result.model_id).toEqual('.elser_model_1'); }); test('provides the requested version of a recommended architecture', async () => { const result = await modelService.getELSER({ version: 2 }); - expect(result.name).toEqual('.elser_model_2_linux-x86_64'); + expect(result.model_id).toEqual('.elser_model_2_linux-x86_64'); + }); + }); + + describe('getCuratedModelConfig', () => { + test('provides a recommended definition by default', async () => { + const result = await modelService.getCuratedModelConfig('e5'); + expect(result.model_id).toEqual('.multilingual-e5-small_linux-x86_64'); + }); + + test('provides a default version if there is no recommended', async () => { + mockCloud.cloudId = undefined; + (mockClient.asInternalUser.transport.request as jest.Mock).mockResolvedValueOnce({ + _nodes: { + total: 1, + successful: 1, + failed: 0, + }, + cluster_name: 'default', + nodes: { + yYmqBqjpQG2rXsmMSPb9pQ: { + name: 'node-0', + roles: ['ml'], + attributes: {}, + os: { + name: 'Mac OS X', + arch: 'aarch64', + }, + }, + }, + }); + + const result = await modelService.getCuratedModelConfig('e5'); + expect(result.model_id).toEqual('.multilingual-e5-small'); }); }); }); diff --git a/x-pack/plugins/ml/server/models/model_management/models_provider.ts b/x-pack/plugins/ml/server/models/model_management/models_provider.ts index e6243b38324bf..6d3ba51a9b76b 100644 --- a/x-pack/plugins/ml/server/models/model_management/models_provider.ts +++ b/x-pack/plugins/ml/server/models/model_management/models_provider.ts @@ -19,10 +19,11 @@ import type { } from '@elastic/elasticsearch/lib/api/types'; import { ELASTIC_MODEL_DEFINITIONS, - type GetElserOptions, + type GetModelDownloadConfigOptions, type ModelDefinitionResponse, } from '@kbn/ml-trained-models-utils'; import type { CloudSetup } from '@kbn/cloud-plugin/server'; +import type { ElasticCuratedModelName } from '@kbn/ml-trained-models-utils'; import type { PipelineDefinition } from '../../../common/types/trained_models'; import type { MlClient } from '../../lib/ml_client'; import type { MLSavedObjectService } from '../../saved_objects'; @@ -52,6 +53,8 @@ interface ModelMapResult { error: null | any; } +export type GetCuratedModelConfigParams = Parameters; + export class ModelsProvider { private _transforms?: TransformGetTransformTransformSummary[]; @@ -410,8 +413,6 @@ export class ModelsProvider { } throw error; } - - return result; } /** @@ -460,7 +461,7 @@ export class ModelsProvider { const modelDefinitionMap = new Map(); - for (const [name, def] of Object.entries(ELASTIC_MODEL_DEFINITIONS)) { + for (const [modelId, def] of Object.entries(ELASTIC_MODEL_DEFINITIONS)) { const recommended = (isCloud && def.os === 'Linux' && def.arch === 'amd64') || (sameArch && !!def?.os && def?.os === osName && def?.arch === arch); @@ -470,7 +471,7 @@ export class ModelsProvider { const modelDefinitionResponse = { ...def, ...(recommended ? { recommended } : {}), - name, + model_id: modelId, }; if (modelDefinitionMap.has(modelName)) { @@ -494,14 +495,19 @@ export class ModelsProvider { } /** - * Provides an ELSER model name and configuration for download based on the current cluster architecture. - * The current default version is 2. If running on Cloud it returns the Linux x86_64 optimized version. - * If any of the ML nodes run a different OS rather than Linux, or the CPU architecture isn't x86_64, - * a portable version of the model is returned. + * Provides an appropriate model ID and configuration for download based on the current cluster architecture. + * + * @param modelName + * @param options + * @returns */ - async getELSER(options?: GetElserOptions): Promise | never { - const modelDownloadConfig = await this.getModelDownloads(); - + async getCuratedModelConfig( + modelName: ElasticCuratedModelName, + options?: GetModelDownloadConfigOptions + ): Promise | never { + const modelDownloadConfig = (await this.getModelDownloads()).filter( + (model) => model.modelName === modelName + ); let requestedModel: ModelDefinitionResponse | undefined; let recommendedModel: ModelDefinitionResponse | undefined; let defaultModel: ModelDefinitionResponse | undefined; @@ -527,6 +533,18 @@ export class ModelsProvider { return requestedModel || recommendedModel || defaultModel!; } + /** + * Provides an ELSER model name and configuration for download based on the current cluster architecture. + * The current default version is 2. If running on Cloud it returns the Linux x86_64 optimized version. + * If any of the ML nodes run a different OS rather than Linux, or the CPU architecture isn't x86_64, + * a portable version of the model is returned. + */ + async getELSER( + options?: GetModelDownloadConfigOptions + ): Promise | never { + return await this.getCuratedModelConfig('elser', options); + } + /** * Puts the requested ELSER model into elasticsearch, triggering elasticsearch to download the model. * Assigns the model to the * space. @@ -535,7 +553,7 @@ export class ModelsProvider { */ async installElasticModel(modelId: string, mlSavedObjectService: MLSavedObjectService) { const availableModels = await this.getModelDownloads(); - const model = availableModels.find((m) => m.name === modelId); + const model = availableModels.find((m) => m.model_id === modelId); if (!model) { throw Boom.notFound('Model not found'); } @@ -556,7 +574,7 @@ export class ModelsProvider { } const putResponse = await this._mlClient.putTrainedModel({ - model_id: model.name, + model_id: model.model_id, body: model.config, }); diff --git a/x-pack/plugins/ml/server/shared_services/providers/__mocks__/trained_models.ts b/x-pack/plugins/ml/server/shared_services/providers/__mocks__/trained_models.ts index 9af448058ce83..fa37f3d468fc3 100644 --- a/x-pack/plugins/ml/server/shared_services/providers/__mocks__/trained_models.ts +++ b/x-pack/plugins/ml/server/shared_services/providers/__mocks__/trained_models.ts @@ -16,7 +16,8 @@ const trainedModelsServiceMock = { deleteTrainedModel: jest.fn(), updateTrainedModelDeployment: jest.fn(), putTrainedModel: jest.fn(), - getELSER: jest.fn().mockResolvedValue({ name: '' }), + getELSER: jest.fn().mockResolvedValue({ model_id: '.elser_model_2' }), + getCuratedModelConfig: jest.fn().mockResolvedValue({ model_id: '.elser_model_2' }), } as jest.Mocked; export const createTrainedModelsProviderMock = () => diff --git a/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts b/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts index 4a1edbbcb3e4d..6b04a3e7580d9 100644 --- a/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts +++ b/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts @@ -8,7 +8,10 @@ import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; import type { CloudSetup } from '@kbn/cloud-plugin/server'; import type { KibanaRequest, SavedObjectsClientContract } from '@kbn/core/server'; -import type { GetElserOptions, ModelDefinitionResponse } from '@kbn/ml-trained-models-utils'; +import type { + GetModelDownloadConfigOptions, + ModelDefinitionResponse, +} from '@kbn/ml-trained-models-utils'; import type { MlInferTrainedModelRequest, MlStopTrainedModelDeploymentRequest, @@ -16,6 +19,7 @@ import type { UpdateTrainedModelDeploymentResponse, } from '../../lib/ml_client/types'; import { modelsProvider } from '../../models/model_management'; +import type { GetCuratedModelConfigParams } from '../../models/model_management/models_provider'; import type { GetGuards } from '../shared_services'; export interface TrainedModelsProvider { @@ -47,7 +51,8 @@ export interface TrainedModelsProvider { putTrainedModel( params: estypes.MlPutTrainedModelRequest ): Promise; - getELSER(params?: GetElserOptions): Promise; + getELSER(params?: GetModelDownloadConfigOptions): Promise; + getCuratedModelConfig(...params: GetCuratedModelConfigParams): Promise; }; } @@ -123,7 +128,7 @@ export function getTrainedModelsProvider( return mlClient.putTrainedModel(params); }); }, - async getELSER(params?: GetElserOptions) { + async getELSER(params?: GetModelDownloadConfigOptions) { return await guards .isFullLicense() .hasMlCapabilities(['canGetTrainedModels']) @@ -131,6 +136,14 @@ export function getTrainedModelsProvider( return modelsProvider(scopedClient, mlClient, cloud).getELSER(params); }); }, + async getCuratedModelConfig(...params: GetCuratedModelConfigParams) { + return await guards + .isFullLicense() + .hasMlCapabilities(['canGetTrainedModels']) + .ok(async ({ scopedClient, mlClient }) => { + return modelsProvider(scopedClient, mlClient, cloud).getCuratedModelConfig(...params); + }); + }, }; }, };