Skip to content

Commit

Permalink
Use orjson instead of json, when available (#17955)
Browse files Browse the repository at this point in the history
For `mypy -c 'import torch'`, the cache load time goes from 0.44s to
0.25s, as measured by the manager's data_json_load_time stat. If I time the
JSON dumps specifically, the time drops from 0.65s to 0.07s. Overall, a pretty
reasonable perf win -- should we make it a required dependency?

See also #3456
  • Loading branch information
hauntsaninja committed Oct 20, 2024
1 parent 2cd2406 commit 7c27808
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 68 deletions.
12 changes: 6 additions & 6 deletions misc/apply-cache-diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from __future__ import annotations

import argparse
import json
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from mypy.metastore import FilesystemMetadataStore, MetadataStore, SqliteMetadataStore
from mypy.util import json_dumps, json_loads


def make_cache(input_dir: str, sqlite: bool) -> MetadataStore:
Expand All @@ -26,21 +26,21 @@ def make_cache(input_dir: str, sqlite: bool) -> MetadataStore:

def apply_diff(cache_dir: str, diff_file: str, sqlite: bool = False) -> None:
cache = make_cache(cache_dir, sqlite)
with open(diff_file) as f:
diff = json.load(f)
with open(diff_file, "rb") as f:
diff = json_loads(f.read())

old_deps = json.loads(cache.read("@deps.meta.json"))
old_deps = json_loads(cache.read("@deps.meta.json"))

for file, data in diff.items():
if data is None:
cache.remove(file)
else:
cache.write(file, data)
if file.endswith(".meta.json") and "@deps" not in file:
meta = json.loads(data)
meta = json_loads(data)
old_deps["snapshot"][meta["id"]] = meta["hash"]

cache.write("@deps.meta.json", json.dumps(old_deps))
cache.write("@deps.meta.json", json_dumps(old_deps))

cache.commit()

Expand Down
14 changes: 7 additions & 7 deletions misc/diff-cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from __future__ import annotations

import argparse
import json
import os
import sys
from collections import defaultdict
Expand All @@ -17,6 +16,7 @@
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from mypy.metastore import FilesystemMetadataStore, MetadataStore, SqliteMetadataStore
from mypy.util import json_dumps, json_loads


def make_cache(input_dir: str, sqlite: bool) -> MetadataStore:
Expand All @@ -33,7 +33,7 @@ def merge_deps(all: dict[str, set[str]], new: dict[str, set[str]]) -> None:

def load(cache: MetadataStore, s: str) -> Any:
data = cache.read(s)
obj = json.loads(data)
obj = json_loads(data)
if s.endswith(".meta.json"):
# For meta files, zero out the mtimes and sort the
# dependencies to avoid spurious conflicts
Expand Down Expand Up @@ -73,7 +73,7 @@ def main() -> None:
type_misses: dict[str, int] = defaultdict(int)
type_hits: dict[str, int] = defaultdict(int)

updates: dict[str, str | None] = {}
updates: dict[str, bytes | None] = {}

deps1: dict[str, set[str]] = {}
deps2: dict[str, set[str]] = {}
Expand All @@ -96,7 +96,7 @@ def main() -> None:
# so we can produce a much smaller direct diff of them.
if ".deps." not in s:
if obj2 is not None:
updates[s] = json.dumps(obj2)
updates[s] = json_dumps(obj2)
else:
updates[s] = None
elif obj2:
Expand All @@ -122,7 +122,7 @@ def main() -> None:
merge_deps(new_deps, root_deps)

new_deps_json = {k: list(v) for k, v in new_deps.items() if v}
updates["@root.deps.json"] = json.dumps(new_deps_json)
updates["@root.deps.json"] = json_dumps(new_deps_json)

# Drop updates to deps.meta.json for size reasons. The diff
# applier will manually fix it up.
Expand All @@ -136,8 +136,8 @@ def main() -> None:
print("hits", type_hits)
print("misses", type_misses)

with open(args.output, "w") as f:
json.dump(updates, f)
with open(args.output, "wb") as f:
f.write(json_dumps(updates))


if __name__ == "__main__":
Expand Down
45 changes: 18 additions & 27 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
from mypy.stubinfo import legacy_bundled_packages, non_bundled_packages, stub_distribution_name
from mypy.types import Type
from mypy.typestate import reset_global_state, type_state
from mypy.util import json_dumps, json_loads
from mypy.version import __version__

# Switch to True to produce debug output related to fine-grained incremental
Expand Down Expand Up @@ -858,7 +859,7 @@ def load_fine_grained_deps(self, id: str) -> dict[str, set[str]]:
t0 = time.time()
if id in self.fg_deps_meta:
# TODO: Assert deps file wasn't changed.
deps = json.loads(self.metastore.read(self.fg_deps_meta[id]["path"]))
deps = json_loads(self.metastore.read(self.fg_deps_meta[id]["path"]))
else:
deps = {}
val = {k: set(v) for k, v in deps.items()}
Expand Down Expand Up @@ -911,8 +912,8 @@ def stats_summary(self) -> Mapping[str, object]:
return self.stats


def deps_to_json(x: dict[str, set[str]]) -> str:
return json.dumps({k: list(v) for k, v in x.items()}, separators=(",", ":"))
def deps_to_json(x: dict[str, set[str]]) -> bytes:
return json_dumps({k: list(v) for k, v in x.items()})


# File for storing metadata about all the fine-grained dependency caches
Expand Down Expand Up @@ -980,7 +981,7 @@ def write_deps_cache(

meta = {"snapshot": meta_snapshot, "deps_meta": fg_deps_meta}

if not metastore.write(DEPS_META_FILE, json.dumps(meta, separators=(",", ":"))):
if not metastore.write(DEPS_META_FILE, json_dumps(meta)):
manager.log(f"Error writing fine-grained deps meta JSON file {DEPS_META_FILE}")
error = True

Expand Down Expand Up @@ -1048,7 +1049,7 @@ def generate_deps_for_cache(manager: BuildManager, graph: Graph) -> dict[str, di

def write_plugins_snapshot(manager: BuildManager) -> None:
"""Write snapshot of versions and hashes of currently active plugins."""
snapshot = json.dumps(manager.plugins_snapshot, separators=(",", ":"))
snapshot = json_dumps(manager.plugins_snapshot)
if not manager.metastore.write(PLUGIN_SNAPSHOT_FILE, snapshot):
manager.errors.set_file(_cache_dir_prefix(manager.options), None, manager.options)
manager.errors.report(0, 0, "Error writing plugins snapshot", blocker=True)
Expand Down Expand Up @@ -1079,8 +1080,8 @@ def read_quickstart_file(
# just ignore it.
raw_quickstart: dict[str, Any] = {}
try:
with open(options.quickstart_file) as f:
raw_quickstart = json.load(f)
with open(options.quickstart_file, "rb") as f:
raw_quickstart = json_loads(f.read())

quickstart = {}
for file, (x, y, z) in raw_quickstart.items():
Expand Down Expand Up @@ -1148,10 +1149,10 @@ def _load_json_file(
manager.add_stats(metastore_read_time=time.time() - t0)
# Only bother to compute the log message if we are logging it, since it could be big
if manager.verbosity() >= 2:
manager.trace(log_success + data.rstrip())
manager.trace(log_success + data.rstrip().decode())
try:
t1 = time.time()
result = json.loads(data)
result = json_loads(data)
manager.add_stats(data_json_load_time=time.time() - t1)
except json.JSONDecodeError:
manager.errors.set_file(file, None, manager.options)
Expand Down Expand Up @@ -1343,8 +1344,8 @@ def find_cache_meta(id: str, path: str, manager: BuildManager) -> CacheMeta | No
# So that plugins can return data with tuples in it without
# things silently always invalidating modules, we round-trip
# the config data. This isn't beautiful.
plugin_data = json.loads(
json.dumps(manager.plugin.report_config_data(ReportConfigContext(id, path, is_check=True)))
plugin_data = json_loads(
json_dumps(manager.plugin.report_config_data(ReportConfigContext(id, path, is_check=True)))
)
if m.plugin_data != plugin_data:
manager.log(f"Metadata abandoned for {id}: plugin configuration differs")
Expand Down Expand Up @@ -1478,18 +1479,15 @@ def validate_meta(
"ignore_all": meta.ignore_all,
"plugin_data": meta.plugin_data,
}
if manager.options.debug_cache:
meta_str = json.dumps(meta_dict, indent=2, sort_keys=True)
else:
meta_str = json.dumps(meta_dict, separators=(",", ":"))
meta_bytes = json_dumps(meta_dict, manager.options.debug_cache)
meta_json, _, _ = get_cache_names(id, path, manager.options)
manager.log(
"Updating mtime for {}: file {}, meta {}, mtime {}".format(
id, path, meta_json, meta.mtime
)
)
t1 = time.time()
manager.metastore.write(meta_json, meta_str) # Ignore errors, just an optimization.
manager.metastore.write(meta_json, meta_bytes) # Ignore errors, just an optimization.
manager.add_stats(validate_update_time=time.time() - t1, validate_munging_time=t1 - t0)
return meta

Expand All @@ -1507,13 +1505,6 @@ def compute_hash(text: str) -> str:
return hash_digest(text.encode("utf-8"))


def json_dumps(obj: Any, debug_cache: bool) -> str:
if debug_cache:
return json.dumps(obj, indent=2, sort_keys=True)
else:
return json.dumps(obj, sort_keys=True, separators=(",", ":"))


def write_cache(
id: str,
path: str,
Expand Down Expand Up @@ -1566,8 +1557,8 @@ def write_cache(

# Serialize data and analyze interface
data = tree.serialize()
data_str = json_dumps(data, manager.options.debug_cache)
interface_hash = compute_hash(data_str)
data_bytes = json_dumps(data, manager.options.debug_cache)
interface_hash = hash_digest(data_bytes)

plugin_data = manager.plugin.report_config_data(ReportConfigContext(id, path, is_check=False))

Expand All @@ -1591,7 +1582,7 @@ def write_cache(
manager.trace(f"Interface for {id} is unchanged")
else:
manager.trace(f"Interface for {id} has changed")
if not metastore.write(data_json, data_str):
if not metastore.write(data_json, data_bytes):
# Most likely the error is the replace() call
# (see https://github.com/python/mypy/issues/3215).
manager.log(f"Error writing data JSON file {data_json}")
Expand Down Expand Up @@ -3568,4 +3559,4 @@ def write_undocumented_ref_info(
assert not ref_info_file.startswith(".")

deps_json = get_undocumented_ref_info_json(state.tree, type_map)
metastore.write(ref_info_file, json.dumps(deps_json, separators=(",", ":")))
metastore.write(ref_info_file, json_dumps(deps_json))
39 changes: 16 additions & 23 deletions mypy/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ def getmtime(self, name: str) -> float:
"""

@abstractmethod
def read(self, name: str) -> str:
def read(self, name: str) -> bytes:
"""Read the contents of a metadata entry.
Raises FileNotFound if the entry does not exist.
"""

@abstractmethod
def write(self, name: str, data: str, mtime: float | None = None) -> bool:
def write(self, name: str, data: bytes, mtime: float | None = None) -> bool:
"""Write a metadata entry.
If mtime is specified, set it as the mtime of the entry. Otherwise,
Expand Down Expand Up @@ -86,16 +86,16 @@ def getmtime(self, name: str) -> float:

return int(os.path.getmtime(os.path.join(self.cache_dir_prefix, name)))

def read(self, name: str) -> str:
def read(self, name: str) -> bytes:
assert os.path.normpath(name) != os.path.abspath(name), "Don't use absolute paths!"

if not self.cache_dir_prefix:
raise FileNotFoundError()

with open(os.path.join(self.cache_dir_prefix, name)) as f:
with open(os.path.join(self.cache_dir_prefix, name), "rb") as f:
return f.read()

def write(self, name: str, data: str, mtime: float | None = None) -> bool:
def write(self, name: str, data: bytes, mtime: float | None = None) -> bool:
assert os.path.normpath(name) != os.path.abspath(name), "Don't use absolute paths!"

if not self.cache_dir_prefix:
Expand All @@ -105,7 +105,7 @@ def write(self, name: str, data: str, mtime: float | None = None) -> bool:
tmp_filename = path + "." + random_string()
try:
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(tmp_filename, "w") as f:
with open(tmp_filename, "wb") as f:
f.write(data)
os.replace(tmp_filename, path)
if mtime is not None:
Expand Down Expand Up @@ -135,27 +135,20 @@ def list_all(self) -> Iterable[str]:


SCHEMA = """
CREATE TABLE IF NOT EXISTS files (
CREATE TABLE IF NOT EXISTS files2 (
path TEXT UNIQUE NOT NULL,
mtime REAL,
data TEXT
data BLOB
);
CREATE INDEX IF NOT EXISTS path_idx on files(path);
CREATE INDEX IF NOT EXISTS path_idx on files2(path);
"""
# No migrations yet
MIGRATIONS: list[str] = []


def connect_db(db_file: str) -> sqlite3.Connection:
import sqlite3.dbapi2

db = sqlite3.dbapi2.connect(db_file)
db.executescript(SCHEMA)
for migr in MIGRATIONS:
try:
db.executescript(migr)
except sqlite3.OperationalError:
pass
return db


Expand All @@ -176,7 +169,7 @@ def _query(self, name: str, field: str) -> Any:
if not self.db:
raise FileNotFoundError()

cur = self.db.execute(f"SELECT {field} FROM files WHERE path = ?", (name,))
cur = self.db.execute(f"SELECT {field} FROM files2 WHERE path = ?", (name,))
results = cur.fetchall()
if not results:
raise FileNotFoundError()
Expand All @@ -188,12 +181,12 @@ def getmtime(self, name: str) -> float:
assert isinstance(mtime, float)
return mtime

def read(self, name: str) -> str:
def read(self, name: str) -> bytes:
data = self._query(name, "data")
assert isinstance(data, str)
assert isinstance(data, bytes)
return data

def write(self, name: str, data: str, mtime: float | None = None) -> bool:
def write(self, name: str, data: bytes, mtime: float | None = None) -> bool:
import sqlite3

if not self.db:
Expand All @@ -202,7 +195,7 @@ def write(self, name: str, data: str, mtime: float | None = None) -> bool:
if mtime is None:
mtime = time.time()
self.db.execute(
"INSERT OR REPLACE INTO files(path, mtime, data) VALUES(?, ?, ?)",
"INSERT OR REPLACE INTO files2(path, mtime, data) VALUES(?, ?, ?)",
(name, mtime, data),
)
except sqlite3.OperationalError:
Expand All @@ -213,13 +206,13 @@ def remove(self, name: str) -> None:
if not self.db:
raise FileNotFoundError()

self.db.execute("DELETE FROM files WHERE path = ?", (name,))
self.db.execute("DELETE FROM files2 WHERE path = ?", (name,))

def commit(self) -> None:
if self.db:
self.db.commit()

def list_all(self) -> Iterable[str]:
if self.db:
for row in self.db.execute("SELECT path FROM files"):
for row in self.db.execute("SELECT path FROM files2"):
yield row[0]
Loading

0 comments on commit 7c27808

Please sign in to comment.