Merge pull request #167 from citadel-ai/rm-yapf
[nit] Remove yapf
liwii authored Nov 15, 2024
2 parents d9ff16f + 421e680 commit 81a72ef
Showing 2 changed files with 15 additions and 19 deletions.
.devcontainer/devcontainer.json (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@
 // Use 'postCreateCommand' to run commands after the container is created.
 // PyTorch 2.1.0 causes segmentation fault in aarch64, so we pin the version in the dev container until the bug is fixed.
 // Ref: https://github.com/pytorch/pytorch/issues/110819
-"postCreateCommand": "curl https://sh.rustup.rs -sSf | bash -s -- -y && . $HOME/.cargo/env && pip install --upgrade pip && pip install yapf==0.40.1 && python -m pip install -e .[no-local-llm,dev]",
+"postCreateCommand": "curl https://sh.rustup.rs -sSf | bash -s -- -y && . $HOME/.cargo/env && pip install --upgrade pip && python -m pip install -e .[no-local-llm,dev]",
 "customizations": {
 // Configure properties specific to VS Code.
 "vscode": {
src/langcheck/metrics/zh/reference_free_text_quality.py (32 changes: 14 additions & 18 deletions)
@@ -82,12 +82,10 @@ def sentiment(
     _predict_result = _sentiment_pipeline(generated_outputs)  # type: ignore[reportGeneralTypeIssues]
     # if predicted result is 'Positive', use the score directly
     # else, use 1 - score as the sentiment score
-    # yapf: disable
     scores = [
         1 - x["score"] if x["label"] == _model_id2label[0] else x["score"]  # type: ignore[reportGeneralTypeIssues]
-        for x in _predict_result  # type: ignore[reportGeneralTypeIssues]
+        for x in _predict_result  # type: ignore[reportGeneralTypeIssues]
     ]
-    # yapf: enable
     return MetricValue(
         metric_name="sentiment",
         metric_inputs=metric_inputs,
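The `scores` comprehension above folds the classifier output onto a single 0-1 sentiment scale: a prediction for the negative label with confidence p becomes 1 - p, anything else keeps its score. Below is a minimal sketch of that mapping, assuming the pipeline returns one {"label": ..., "score": ...} dict per input and that `_model_id2label[0]` is the negative class; `NEGATIVE_LABEL` and `map_sentiment_scores` are illustrative names, not identifiers from langcheck.

```python
# Sketch only, not repository code: fold (label, confidence) pairs onto one
# 0 (negative) .. 1 (positive) scale, mirroring the comprehension in the hunk above.

NEGATIVE_LABEL = "negative"  # stand-in for _model_id2label[0] (assumed)


def map_sentiment_scores(predictions: list[dict]) -> list[float]:
    # A negative prediction with confidence p becomes 1 - p; other labels keep p.
    return [
        1 - p["score"] if p["label"] == NEGATIVE_LABEL else p["score"]
        for p in predictions
    ]


print(map_sentiment_scores([
    {"label": NEGATIVE_LABEL, "score": 0.9},  # confidently negative -> ~0.1
    {"label": "positive", "score": 0.8},      # confidently positive -> 0.8
]))  # ≈ [0.1, 0.8]
```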
@@ -200,10 +198,8 @@ def _toxicity_local(generated_outputs: list[str]) -> list[float]:
     toxicity_scores = []
     for item_predict_proba in _predict_results:  # type: ignore[reportOptionalIterable]
         for label_proba in item_predict_proba:  # type: ignore[reportGeneralTypeIssues]
-            # yapf: disable
             if label_proba["label"] == _model_id2label[0]:  # type: ignore[reportGeneralTypeIssues]
                 toxicity_scores.append(1 - label_proba["score"])  # type: ignore[reportGeneralTypeIssues]
-            # yapf: enable
     return toxicity_scores  # type: ignore[reportGeneralTypeIssues]
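The loop above walks a nested structure (one list of per-label probabilities per generated output) and keeps 1 - score for the label at index 0 of the model's `id2label`, which appears to be the non-toxic class. A small self-contained sketch under that assumption; the names and example data below are made up.

```python
# Sketch only, not repository code: extract a toxicity score per output, assuming
# each prediction lists every class with its probability and that the label at
# index 0 of id2label is the non-toxic class (hence the 1 - score inversion).

NON_TOXIC_LABEL = "LABEL_0"  # stand-in for _model_id2label[0] (assumed)


def extract_toxicity_scores(predictions: list[list[dict]]) -> list[float]:
    scores = []
    for per_label_probas in predictions:  # one inner list per generated output
        for proba in per_label_probas:
            if proba["label"] == NON_TOXIC_LABEL:
                scores.append(1 - proba["score"])  # toxicity = 1 - P(non-toxic)
    return scores


print(extract_toxicity_scores([
    [{"label": "LABEL_0", "score": 0.95}, {"label": "LABEL_1", "score": 0.05}],
    [{"label": "LABEL_0", "score": 0.20}, {"label": "LABEL_1", "score": 0.80}],
]))  # ≈ [0.05, 0.80]
```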


@@ -240,31 +236,32 @@ def xuyaochen_report_readability(
         prompts=prompts,
         required_params=["generated_outputs"],
     )
-    # yapf: disable
     tokenizer = hanlp.load(
         hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH  # type: ignore[reportGeneralTypeIssues]
     )
     postagger = hanlp.load(
-        hanlp.pretrained.pos.CTB9_POS_RADICAL_ELECTRA_SMALL  # type: ignore[reportGeneralTypeIssues]
+        hanlp.pretrained.pos.CTB9_POS_RADICAL_ELECTRA_SMALL  # type: ignore[reportGeneralTypeIssues]
     )

-    pos_pipeline = hanlp.pipeline().\
-        append(hanlp.utils.rules.split_sentence)  # type: ignore[reportGeneralTypeIssues]
+    pos_pipeline = hanlp.pipeline().append(hanlp.utils.rules.split_sentence)  # type: ignore[reportGeneralTypeIssues]
     pos_pipeline = pos_pipeline.append(tokenizer).append(postagger)

-    tokenize_pipeline = hanlp.pipeline().\
-        append(hanlp.utils.rules.split_sentence)  # type: ignore[reportGeneralTypeIssues]
+    tokenize_pipeline = hanlp.pipeline().append(
+        hanlp.utils.rules.split_sentence  # type: ignore[reportGeneralTypeIssues]
+    )
     tokenize_pipeline = tokenize_pipeline.append(tokenizer)
     # OUTPUT: List[List[List[TOKEN]]]
     output_tokens = list(map(tokenize_pipeline, generated_outputs))
     # List[List[List[POS]]]
     output_pos = list(map(pos_pipeline, generated_outputs))

     def count_tokens(sent_tokens: list[str]) -> int:
-        count = sum([
-            not hanlp.utils.string_util.ispunct(token) for token in  # type: ignore[reportGeneralTypeIssues]
-            sent_tokens
-        ])
+        count = sum(
+            [
+                not hanlp.utils.string_util.ispunct(token)  # type: ignore[reportGeneralTypeIssues]
+                for token in sent_tokens
+            ]
+        )
         return count

     def count_postags(sent_poses: list[str]) -> int:
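Per the `# OUTPUT: List[List[List[TOKEN]]]` comment, each generated output becomes a list of sentences, each a list of tokens, and `count_tokens` counts a sentence's non-punctuation tokens by summing booleans. A rough sketch with a plain-Python stand-in for `hanlp.utils.string_util.ispunct` (the stand-in is an assumption, not hanlp's actual implementation):

```python
# Sketch only, not repository code: count the non-punctuation tokens in one sentence.
import unicodedata


def is_punct(token: str) -> bool:
    # Stand-in for hanlp.utils.string_util.ispunct: a token counts as punctuation
    # when every character falls in a Unicode punctuation category.
    return all(unicodedata.category(ch).startswith("P") for ch in token)


def count_tokens(sent_tokens: list[str]) -> int:
    # True/False sum as 1/0, so this counts tokens that are not punctuation.
    return sum(not is_punct(token) for token in sent_tokens)


print(count_tokens(["今天", "天气", "很", "好", "。"]))  # 4 (the full stop is excluded)
```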
@@ -287,10 +284,9 @@ def calc_r2(content: list[list[str]]) -> float:
         else:
             return sum(pos_count_by_sentence) / len(pos_count_by_sentence)

-    r1 = list(map(calc_r1, output_tokens))  # type: ignore[reportGeneralTypeIssues]
-    r2 = list(map(calc_r2, output_pos))  # type: ignore[reportGeneralTypeIssues]
+    r1 = list(map(calc_r1, output_tokens))  # type: ignore[reportGeneralTypeIssues]
+    r2 = list(map(calc_r2, output_pos))  # type: ignore[reportGeneralTypeIssues]
     r3 = [(r1_score + r2_score) * 0.5 for r1_score, r2_score in zip(r1, r2)]
-    # yapf: enable
     return MetricValue(
         metric_name="readability",
         metric_inputs=metric_inputs,
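Taking the readability hunks together: per generated output, r1 averages a token count per sentence, r2 averages a POS-based count per sentence (`calc_r2` above), and r3 is their mean. A compact sketch of that combination over made-up per-sentence counts, assuming `calc_r1` mirrors `calc_r2` but over token counts (neither `calc_r1` nor `count_postags` is shown in this diff):

```python
# Sketch only, not repository code: combine per-sentence counts into r1, r2, r3.


def mean(counts: list[int]) -> float:
    return sum(counts) / len(counts) if counts else 0.0


# One inner list per generated output: a count for each of its sentences (made up).
token_counts_per_output = [[12, 9, 15], [7, 8]]   # e.g. non-punctuation tokens
postag_counts_per_output = [[3, 2, 4], [1, 2]]    # e.g. counts from count_postags

r1 = [mean(counts) for counts in token_counts_per_output]
r2 = [mean(counts) for counts in postag_counts_per_output]
r3 = [(r1_score + r2_score) * 0.5 for r1_score, r2_score in zip(r1, r2)]

print(r3)  # [7.5, 4.5]
```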
