Skip to content

Commit

Permalink
Report LLM token usage (#991)
Browse files Browse the repository at this point in the history
* report token usage at end of codemodder run

* move log
  • Loading branch information
clavedeluna authored Feb 6, 2025
1 parent d615dd7 commit b2baa38
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 28 deletions.
37 changes: 22 additions & 15 deletions src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from codemodder.codetf import CodeTF
from codemodder.context import CodemodExecutionContext
from codemodder.dependency import Dependency
from codemodder.llm import MisconfiguredAIClient
from codemodder.llm import MisconfiguredAIClient, TokenUsage, log_token_usage
from codemodder.logging import configure_logger, log_list, log_section, logger
from codemodder.project_analysis.file_parsers.package_store import PackageStore
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
Expand Down Expand Up @@ -46,7 +46,7 @@ def find_semgrep_results(
return run_semgrep(context, yaml_files, files_to_analyze)


def log_report(context, output, elapsed_ms, files_to_analyze):
def log_report(context, output, elapsed_ms, files_to_analyze, token_usage):
log_section("report")
logger.info("scanned: %s files", len(files_to_analyze))
all_failures = context.get_failed_files()
Expand All @@ -62,6 +62,7 @@ def log_report(context, output, elapsed_ms, files_to_analyze):
len(set(all_changes)),
)
logger.info("report file: %s", output)
log_token_usage("All", token_usage)
logger.info("total elapsed: %s ms", elapsed_ms)
logger.info(" semgrep: %s ms", context.timer.get_time_ms("semgrep"))
logger.info(" parse: %s ms", context.timer.get_time_ms("parse"))
Expand All @@ -72,24 +73,30 @@ def log_report(context, output, elapsed_ms, files_to_analyze):
def apply_codemods(
    context: CodemodExecutionContext,
    codemods_to_run: Sequence[BaseCodemod],
) -> TokenUsage:
    """Run each codemod in *codemods_to_run* against *context*, in order.

    :param context: execution context holding the files to analyze
    :param codemods_to_run: codemods to apply, respecting the given sequence
    :return: aggregate LLM token usage across all codemods; an empty
        ``TokenUsage`` when there is nothing to scan or nothing to run
    """
    log_section("scanning")
    token_usage = TokenUsage()

    if not context.files_to_analyze:
        logger.info("no files to scan")
        return token_usage

    if not codemods_to_run:
        logger.info("no codemods to run")
        return token_usage

    # run codemods one at a time making sure to respect the given sequence
    for codemod in codemods_to_run:
        # NOTE: this may be used as a progress indicator by upstream tools
        logger.info("running codemod %s", codemod.id)
        codemod_token_usage = codemod.apply(context)
        # Non-LLM codemods may return None; only log/accumulate real usage.
        if codemod_token_usage:
            log_token_usage(f"Codemod {codemod.id}", codemod_token_usage)
            token_usage += codemod_token_usage

        record_dependency_update(context.process_dependencies(codemod.id))
        context.log_changes(codemod.id)
    return token_usage


def record_dependency_update(dependency_results: dict[Dependency, PackageStore | None]):
Expand Down Expand Up @@ -128,7 +135,7 @@ def run(
codemod_registry: registry.CodemodRegistry | None = None,
sast_only: bool = False,
ai_client: bool = True,
) -> tuple[CodeTF | None, int]:
) -> tuple[CodeTF | None, int, TokenUsage]:
start = datetime.datetime.now()

codemod_registry = codemod_registry or registry.load_registered_codemods()
Expand All @@ -139,6 +146,7 @@ def run(
codemod_exclude = codemod_exclude or []

provider_registry = providers.load_providers()
token_usage = TokenUsage()

log_section("startup")
logger.info("codemodder: python/%s", __version__)
Expand All @@ -148,7 +156,7 @@ def run(
logger.error(
f"FileNotFoundError: [Errno 2] No such file or directory: '{file_name}'"
)
return None, 1
return None, 1, token_usage

repo_manager = PythonRepoManager(Path(directory))

Expand All @@ -168,7 +176,8 @@ def run(
)
except MisconfiguredAIClient as e:
logger.error(e)
return None, 3 # Codemodder instructions conflicted (according to spec)
# Codemodder instructions conflicted (according to spec)
return None, 3, token_usage

context.repo_manager.parse_project()

Expand All @@ -194,10 +203,7 @@ def run(
context.find_and_fix_paths,
)

apply_codemods(
context,
codemods_to_run,
)
token_usage = apply_codemods(context, codemods_to_run)

elapsed = datetime.datetime.now() - start
elapsed_ms = int(elapsed.total_seconds() * 1000)
Expand All @@ -217,8 +223,9 @@ def run(
output,
elapsed_ms,
[] if not codemods_to_run else context.files_to_analyze,
token_usage,
)
return codetf, 0
return codetf, 0, token_usage


def _run_cli(original_args) -> int:
Expand Down Expand Up @@ -258,7 +265,7 @@ def _run_cli(original_args) -> int:
logger.info("command: %s %s", Path(sys.argv[0]).name, " ".join(original_args))
configure_logger(argv.verbose, argv.log_format, argv.project_name)

_, status = run(
_, status, _ = run(
argv.directory,
argv.dry_run,
argv.output,
Expand Down
20 changes: 11 additions & 9 deletions src/codemodder/codemods/base_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from codemodder.codetf import DetectionTool, Reference
from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext
from codemodder.llm import TokenUsage
from codemodder.logging import logger
from codemodder.result import ResultSet

Expand Down Expand Up @@ -188,15 +189,15 @@ def _apply(
self,
context: CodemodExecutionContext,
rules: list[str],
) -> None:
) -> None | TokenUsage:
if self.provider and (
not (provider := context.providers.get_provider(self.provider))
or not provider.is_available
):
logger.warning(
"provider %s is not available, skipping codemod", self.provider
)
return
return None

if isinstance(self.detector, SemgrepRuleDetector):
if (
Expand All @@ -208,7 +209,7 @@ def _apply(
"no results from semgrep for %s, skipping analysis",
self.id,
)
return
return None

results: ResultSet | None = (
# It seems like semgrep doesn't like our fully-specified id format so pass in short name instead.
Expand All @@ -219,11 +220,11 @@ def _apply(

if results is not None and not results:
logger.debug("No results for %s", self.id)
return
return None

if not (files_to_analyze := self.get_files_to_analyze(context, results)):
logger.debug("No files matched for %s", self.id)
return
return None

process_file = functools.partial(
self._process_file, context=context, results=results, rules=rules
Expand All @@ -240,8 +241,9 @@ def _apply(
executor.shutdown(wait=True)

context.process_results(self.id, contexts)
return None

def apply(self, context: CodemodExecutionContext) -> None:
def apply(self, context: CodemodExecutionContext) -> None | TokenUsage:
"""
Apply the codemod with the given codemod execution context
Expand All @@ -257,7 +259,7 @@ def apply(self, context: CodemodExecutionContext) -> None:
:param context: The codemod execution context
"""
self._apply(context, [self._internal_name])
return self._apply(context, [self._internal_name])

def _process_file(
self,
Expand Down Expand Up @@ -355,8 +357,8 @@ def __init__(
if requested_rules:
self.requested_rules.extend(requested_rules)

def apply(self, context: CodemodExecutionContext) -> None | TokenUsage:
    """Apply this codemod using its requested rules.

    :param context: The codemod execution context
    :return: LLM token usage from the run, or ``None`` when nothing ran
    """
    return self._apply(context, self.requested_rules)

def get_files_to_analyze(
self,
Expand Down
29 changes: 29 additions & 0 deletions src/codemodder/llm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING

from typing_extensions import Self

try:
from openai import AzureOpenAI, OpenAI
except ImportError:
Expand All @@ -28,6 +31,8 @@
"setup_openai_llm_client",
"setup_azure_llama_llm_client",
"MisconfiguredAIClient",
"TokenUsage",
"log_token_usage",
]

models = ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13", "gpt-35-turbo-0125"]
Expand Down Expand Up @@ -115,3 +120,27 @@ def setup_azure_llama_llm_client() -> ChatCompletionsClient | None:

class MisconfiguredAIClient(ValueError):
pass


@dataclass
class TokenUsage:
    """Mutable accumulator for LLM token counts.

    Supports in-place accumulation via ``+=`` so per-codemod usage can be
    folded into a running total.
    """

    # Tokens generated by the model in its responses.
    completion_tokens: int = 0
    # Tokens consumed by the prompts sent to the model.
    prompt_tokens: int = 0

    def __iadd__(self, other: Self) -> Self:
        """Fold *other*'s counts into this record and return it."""
        self.completion_tokens = self.completion_tokens + other.completion_tokens
        self.prompt_tokens = self.prompt_tokens + other.prompt_tokens
        return self

    @property
    def total(self):
        """Combined prompt + completion token count."""
        return self.prompt_tokens + self.completion_tokens


def log_token_usage(header: str, token_usage: TokenUsage):
    """Log an INFO-level token-usage summary prefixed with *header*."""
    # Lazy %-style args keep formatting cost off the hot path.
    template = "%s token usage\n\tcompletion_tokens = %s\n\tprompt_tokens = %s"
    logger.info(
        template,
        header,
        token_usage.completion_tokens,
        token_usage.prompt_tokens,
    )
10 changes: 7 additions & 3 deletions tests/test_codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from codemodder import run
from codemodder.codemodder import _run_cli, find_semgrep_results
from codemodder.diff import create_diff_from_tree
from codemodder.llm import TokenUsage
from codemodder.registry import load_registered_codemods
from codemodder.result import ResultSet
from codemodder.semgrep import run as semgrep_run
Expand All @@ -30,7 +31,9 @@ def disable_codemod_apply(mocker, request):
"test_run_codemod_name_or_id",
):
return
mocker.patch("codemodder.codemods.base_codemod.BaseCodemod.apply")
mocker.patch(
"codemodder.codemods.base_codemod.BaseCodemod.apply", return_value=TokenUsage()
)


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -395,7 +398,8 @@ class TestRun:
def test_run_basic_call(self, mock_parse, dir_structure):
code_dir, codetf = dir_structure

codetf_output, status = run(code_dir, dry_run=True)
codetf_output, status, token_usage = run(code_dir, dry_run=True)
assert token_usage.total == 0
assert status == 0
assert codetf_output
assert codetf_output.run.directory == str(code_dir)
Expand All @@ -406,7 +410,7 @@ def test_run_basic_call(self, mock_parse, dir_structure):
def test_run_with_output(self, mock_parse, dir_structure):
code_dir, codetf = dir_structure

codetf_output, status = run(
codetf_output, status, _ = run(
code_dir,
output=codetf,
dry_run=True,
Expand Down
10 changes: 9 additions & 1 deletion tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from codemodder.llm import MODELS, models
from codemodder.llm import MODELS, TokenUsage, models


class TestModels:
Expand All @@ -20,3 +20,11 @@ def test_model_get_name_from_env(self, mocker, model):
},
)
assert getattr(MODELS, attr_name) == name


def test_token_usage():
    """In-place addition accumulates counts and total sums both fields."""
    usage = TokenUsage()
    usage += TokenUsage(10, 5)
    assert usage.completion_tokens == 10
    assert usage.prompt_tokens == 5
    assert usage.total == 15

0 comments on commit b2baa38

Please sign in to comment.