From 9e7a477bb578a9be8b7986ecac36ac2b656c969b Mon Sep 17 00:00:00 2001
From: "nate.river"
Date: Sun, 17 Mar 2024 01:00:52 +0800
Subject: [PATCH] add bce example (#930)

---
 llm/inference/bce/run_bce-embedding.py | 19 +++++++++++++++++++
 llm/inference/bce/run_bce-reranker.py  | 23 +++++++++++++++++++++++
 2 files changed, 42 insertions(+)
 create mode 100644 llm/inference/bce/run_bce-embedding.py
 create mode 100644 llm/inference/bce/run_bce-reranker.py

diff --git a/llm/inference/bce/run_bce-embedding.py b/llm/inference/bce/run_bce-embedding.py
new file mode 100644
index 000000000..99404ba76
--- /dev/null
+++ b/llm/inference/bce/run_bce-embedding.py
@@ -0,0 +1,19 @@
+from mindnlp.transformers import AutoModel, AutoTokenizer
+
+hf_token = 'your_hf_token'
+
+# list of sentences to embed
+sentences = ['sentence_0', 'sentence_1']
+
+# init model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained('maidalun1020/bce-embedding-base_v1', token=hf_token)
+model = AutoModel.from_pretrained('maidalun1020/bce-embedding-base_v1', token=hf_token)
+
+# tokenize the input sentences
+inputs = tokenizer(sentences, padding=True, truncation=True, max_length=512, return_tensors="ms")
+
+# get embeddings: CLS pooling followed by L2 normalization
+outputs = model(**inputs, return_dict=True)
+embeddings = outputs.last_hidden_state[:, 0]  # CLS pooling
+embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)  # L2 normalize
+print(embeddings)
diff --git a/llm/inference/bce/run_bce-reranker.py b/llm/inference/bce/run_bce-reranker.py
new file mode 100644
index 000000000..7093714bd
--- /dev/null
+++ b/llm/inference/bce/run_bce-reranker.py
@@ -0,0 +1,23 @@
+from mindspore import ops
+from mindnlp.transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+hf_token = 'your_hf_token'
+
+# init model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained('maidalun1020/bce-reranker-base_v1', token=hf_token)
+model = AutoModelForSequenceClassification.from_pretrained('maidalun1020/bce-reranker-base_v1', token=hf_token)
+
+# your query and corresponding passages (Chinese: "Shanghai weather"; "Beijing cuisine", "Shanghai climate")
+query = "上海天气"
+passages = ["北京美食", "上海气候"]
+
+# construct query-passage sentence pairs
+sentence_pairs = [[query, passage] for passage in passages]
+
+# tokenize the sentence pairs
+inputs = tokenizer(sentence_pairs, padding=True, truncation=True, max_length=512, return_tensors="ms")
+
+# calculate relevance scores in [0, 1]
+scores = model(**inputs, return_dict=True).logits.view(-1).float()
+scores = ops.sigmoid(scores)
+print(scores)
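
Usage note (not part of the patch): run_bce-embedding.py only prints the normalized sentence embeddings. Because the embeddings are L2-normalized, a plain dot product gives cosine similarity, which is how they are typically consumed in retrieval. The sketch below is illustrative only and assumes it is appended to the end of run_bce-embedding.py; the names query_emb, passage_embs and similarities are made up here and are not part of the example:

    # hypothetical follow-up: treat the first sentence as the query, the rest as passages
    from mindspore import ops
    query_emb = embeddings[:1]      # shape (1, hidden)
    passage_embs = embeddings[1:]   # shape (n-1, hidden)
    # dot product of L2-normalized vectors equals cosine similarity
    similarities = ops.matmul(query_emb, passage_embs.T).squeeze()
    print(similarities)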
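A similar illustrative note for run_bce-reranker.py: each sigmoid score is a per query-passage relevance, so ranking the passages is just a sort. A minimal sketch, assuming it is appended after print(scores) in that script; the ranked name is made up for illustration:

    # hypothetical follow-up: order passages by relevance score, highest first
    ranked = sorted(zip(passages, scores.asnumpy().tolist()), key=lambda pair: pair[1], reverse=True)
    for passage, score in ranked:
        print(f"{score:.4f}\t{passage}")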