diff --git a/examples/Pipfile.ok.extras-list b/examples/Pipfile.ok.extras-list index f7e4fb6..4daa1df 100644 --- a/examples/Pipfile.ok.extras-list +++ b/examples/Pipfile.ok.extras-list @@ -1,4 +1,6 @@ # package extras are a list [packages] msal = {version= "==1.20.0", extras = ["broker"]} +six = 1.11 +zipp = "*" diff --git a/examples/Pipfile.ok.multiple-sources b/examples/Pipfile.ok.multiple-sources new file mode 100644 index 0000000..fa47e44 --- /dev/null +++ b/examples/Pipfile.ok.multiple-sources @@ -0,0 +1,15 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[[source]] +url = "https://download.pytorch.org/whl/cu113/" +verify_ssl = false +name = "pytorch" + +[dev-packages] + +[packages] +torch = {version="*", index="pytorch"} +numpy = {version="*"} diff --git a/src/plette/lockfiles.py b/src/plette/lockfiles.py index f573514..b04d67b 100644 --- a/src/plette/lockfiles.py +++ b/src/plette/lockfiles.py @@ -1,53 +1,41 @@ -# pylint: disable=missing-module-docstring,missing-class-docstring -# pylint: disable=missing-function-docstring -# pylint: disable=no-member - -import dataclasses import json import numbers import collections.abc as collections_abc -from dataclasses import dataclass, field, asdict -from typing import Optional -from .models import BaseModel, Meta, PackageCollection, Package, remove_empty_values +from .models import DataModel, Meta, PackageCollection -PIPFILE_SPEC_CURRENT = 6 +class _LockFileEncoder(json.JSONEncoder): + """A specilized JSON encoder to convert loaded data into a lock file. 
-def flatten_versions(d): - copy = {} - # Iterate over a copy of the dictionary - for key, value in d.items(): - # If the key is "version", replace the key with the value - copy[key] = value["version"] - return copy - - -class DCJSONEncoder(json.JSONEncoder): - def default(self, o): - if dataclasses.is_dataclass(o): - o = dataclasses.asdict(o) - if "_meta" in o: - o["_meta"]["pipfile-spec"] = o["_meta"].pop("pipfile_spec") - o["_meta"]["hash"] = {o["_meta"]["hash"]["name"]: o["_meta"]["hash"]["value"]} - o["_meta"]["sources"] = o["_meta"]["sources"].pop("sources") - - remove_empty_values(o) - - for section in ["default", "develop"]: - try: - o[section] = flatten_versions(o[section]) - except KeyError: - continue - # add silly default values - if "develop" not in o: - o["develop"] = {} - if "requires" not in o["_meta"]: - o["_meta"]["requires"] = {} - return o - return super().default(o) + This adds a few characteristics to the encoder: + + * The JSON is always prettified with indents and spaces. + * The output is always UTF-8-encoded text, never binary, even on Python 2. + """ + def __init__(self): + super(_LockFileEncoder, self).__init__( + indent=4, separators=(",", ": "), sort_keys=True, + ) + + def encode(self, obj): + content = super(_LockFileEncoder, self).encode(obj) + if not isinstance(content, str): + content = content.decode("utf-8") + content += "\n" + return content + + def iterencode(self, obj): + for chunk in super(_LockFileEncoder, self).iterencode(obj): + if not isinstance(chunk, str): + chunk = chunk.decode("utf-8") + yield chunk + yield "\n" + + +PIPFILE_SPEC_CURRENT = 6 def _copy_jsonsafe(value): @@ -64,99 +52,126 @@ def _copy_jsonsafe(value): return str(value) -@dataclass -class Lockfile(BaseModel): - """Representation of a Pipfile.lock.""" - - _meta: Optional[Meta] - default: Optional[dict] = field(default_factory=dict) - develop: Optional[dict] = field(default_factory=dict) - - def __post_init__(self): - """Run validation methods if declared. 
- The validation method can be a simple check - that raises ValueError or a transformation to - the field value. - The validation is performed by calling a function named: - `validate_(self, value) -> field.type` - """ - super().__post_init__() - self.meta = self._meta - - def validate__meta(self, value): - return self.validate_meta(value) - - def validate_meta(self, value): - if "_meta" in value: - value = value["_meta"] - if 'pipfile-spec' in value: - value['pipfile_spec'] = value.pop('pipfile-spec') - return Meta(**value) - - def validate_default(self, value): - packages = {} - for name, spec in value.items(): - packages[name] = Package(spec) - return packages +class Lockfile(DataModel): + """Representation of a Pipfile.lock. + """ + __SCHEMA__ = { + "_meta": {"type": "dict", "required": True}, + "default": {"type": "dict", "required": True}, + "develop": {"type": "dict", "required": True}, + } + + @classmethod + def validate(cls, data): + for key, value in data.items(): + if key == "_meta": + Meta.validate(value) + else: + PackageCollection.validate(value) @classmethod - def load(cls, fh, encoding=None): + def load(cls, f, encoding=None): if encoding is None: - data = json.load(fh) + data = json.load(f) else: - data = json.loads(fh.read().decode(encoding)) - return cls(**data) + data = json.loads(f.read().decode(encoding)) + return cls(data) @classmethod def with_meta_from(cls, pipfile, categories=None): data = { "_meta": { - "hash": pipfile.get_hash().__dict__, + "hash": _copy_jsonsafe(pipfile.get_hash()._data), "pipfile-spec": PIPFILE_SPEC_CURRENT, - "requires": _copy_jsonsafe(getattr(pipfile, "requires", {})), + "requires": _copy_jsonsafe(pipfile._data.get("requires", {})), + "sources": _copy_jsonsafe(pipfile.sources._data), }, } - - data["_meta"].update(asdict(pipfile.sources)) - if categories is None: - data["default"] = _copy_jsonsafe(getattr(pipfile, "packages", {})) - data["develop"] = _copy_jsonsafe(getattr(pipfile, "dev-packages", {})) + 
data["default"] = _copy_jsonsafe(pipfile._data.get("packages", {})) + data["develop"] = _copy_jsonsafe(pipfile._data.get("dev-packages", {})) else: for category in categories: - if category in ["default", "packages"]: - data["default"] = _copy_jsonsafe(getattr(pipfile,"packages", {})) - elif category in ["develop", "dev-packages"]: - data["develop"] = _copy_jsonsafe( - getattr(pipfile,"dev-packages", {})) + if category == "default" or category == "packages": + data["default"] = _copy_jsonsafe(pipfile._data.get("packages", {})) + elif category == "develop" or category == "dev-packages": + data["develop"] = _copy_jsonsafe(pipfile._data.get("dev-packages", {})) else: - data[category] = _copy_jsonsafe(getattr(pipfile, category, {})) + data[category] = _copy_jsonsafe(pipfile._data.get(category, {})) if "default" not in data: - data["default"] = {} + data["default"] = {} if "develop" not in data: data["develop"] = {} return cls(data) def __getitem__(self, key): - value = self[key] + value = self._data[key] try: if key == "_meta": - return Meta(**value) - return PackageCollection(value) + return Meta(value) + else: + return PackageCollection(value) except KeyError: return value + def __setitem__(self, key, value): + if isinstance(value, DataModel): + self._data[key] = value._data + else: + self._data[key] = value + def is_up_to_date(self, pipfile): return self.meta.hash == pipfile.get_hash() - def dump(self, fh): - json.dump(self, fh, cls=DCJSONEncoder) - self.meta = self._meta + def dump(self, f, encoding=None): + encoder = _LockFileEncoder() + if encoding is None: + for chunk in encoder.iterencode(self._data): + f.write(chunk) + else: + content = encoder.encode(self._data) + f.write(content.encode(encoding)) @property def meta(self): - return self._meta + try: + return self["_meta"] + except KeyError: + raise AttributeError("meta") @meta.setter def meta(self, value): - self._meta = value + self["_meta"] = value + + @property + def _meta(self): + try: + return 
self["_meta"] + except KeyError: + raise AttributeError("meta") + + @_meta.setter + def _meta(self, value): + self["_meta"] = value + + @property + def default(self): + try: + return self["default"] + except KeyError: + raise AttributeError("default") + + @default.setter + def default(self, value): + self["default"] = value + + @property + def develop(self): + try: + return self["develop"] + except KeyError: + raise AttributeError("develop") + + @develop.setter + def develop(self, value): + self["develop"] = value diff --git a/src/plette/models.py b/src/plette/models.py deleted file mode 100644 index db6e6dd..0000000 --- a/src/plette/models.py +++ /dev/null @@ -1,355 +0,0 @@ -# pylint: disable=missing-module-docstring,missing-class-docstring -# pylint: disable=missing-function-docstring -# pylint: disable=no-member -# pylint: disable=too-few-public-methods -import os -import re -import shlex - - -from dataclasses import dataclass - -from typing import Optional, List, Union - - -class ValidationError(ValueError): - pass - - -def remove_empty_values(d): - # Iterate over a copy of the dictionary - for key, value in list(d.items()): - # If the value is a dictionary, call the function recursively - if isinstance(value, dict): - remove_empty_values(value) - # If the dictionary is empty, remove the key - if not value: - del d[key] - # If the value is None or an empty string, remove the key - elif value is None or value == '': - del d[key] - - -class BaseModel: - - def __post_init__(self): - """Run validation methods if declared. - The validation method can be a simple check - that raises ValueError or a transformation to - the field value. 
- The validation is performed by calling a function named: - `validate_(self, value) -> field.type` - """ - for name, _ in self.__dataclass_fields__.items(): - if (method := getattr(self, f"validate_{name}", None)): - setattr(self, name, method(getattr(self, name))) - - -@dataclass -class Hash(BaseModel): - - name: str - value: str - - def validate_name(self, value): - if not isinstance(value, str): - raise ValueError("Hash.name must be a string") - - return value - - def validate_value(self, value): - if not isinstance(value, str): - raise ValueError("Hash.value must be a string") - - return value - - @classmethod - def from_hash(cls, ins): - """Interpolation to the hash result of `hashlib`. - """ - return cls(name=ins.name, value=ins.hexdigest()) - - @classmethod - def from_dict(cls, value): - """parse a depedency line and create an Hash object""" - try: - name, value = list(value.items())[0] - except AttributeError: - name, value = value.split(":", 1) - return cls(name, value) - - @classmethod - def from_line(cls, value): - """parse a dependecy line and create a Hash object""" - try: - name, value = value.split(":", 1) - except AttributeError: - name, value = list(value.items())[0] - return cls(name, value) - - def __eq__(self, other): - if not isinstance(other, Hash): - raise TypeError(f"cannot compare Hash with {type(other).__name__!r}") - return self.value == other.value - - def as_line(self): - return f"{self.name}:{self.value}" - - -@dataclass -class Source(BaseModel): - """Information on a "simple" Python package index. - - This could be PyPI, or a self-hosted index server, etc. The server - specified by the `url` attribute is expected to provide the "simple" - package API. 
- """ - name: str - verify_ssl: bool - url: str - - @property - def url_expanded(self): - return os.path.expandvars(self.url) - - def validate_verify_ssl(self, value): - if not isinstance(value, bool): - raise ValidationError("verify_ssl: must be of boolean type") - return value - - -@dataclass -class PackageSpecfiers(BaseModel): - - extras: List[str] - - def validate_extras(self, value): - if not isinstance(value, list): - raise ValidationError("Extras must be a list") - - -@dataclass -class Package(BaseModel): - - version: Union[Optional[str],Optional[dict]] = "*" - specifiers: Optional[PackageSpecfiers] = None - editable: Optional[bool] = None - extras: Optional[PackageSpecfiers] = None - path: Optional[str] = None - - def validate_extras(self, value): - if value is None: - return value - if not (isinstance(value, list) and all(isinstance(i, str) for i in value)): - raise ValidationError("Extras must be a list or None") - return value - - def validate_version(self, value): - if isinstance(value, dict): - return value - if isinstance(value, str): - return value - if value is None: - return "*" - - raise ValidationError(f"Unknown type {type(value)} for version") - - -@dataclass(init=False) -class Script(BaseModel): - - script: Union[str, List[str]] - - def __init__(self, script): - - if isinstance(script, str): - script = shlex.split(script) - self._parts = [script[0]] - self._parts.extend(script[1:]) - - def validate_script(self, value): - if not (isinstance(value, str) or - (isinstance(value, list) and all(isinstance(i, str) for i in value)) - ): - raise ValueError("script must be a string or a list of strings") - - def __repr__(self): - return f"Script({self._parts!r})" - - @property - def command(self): - return self._parts[0] - - @property - def args(self): - return self._parts[1:] - - def cmdify(self, extra_args=None): - """Encode into a cmd-executable string. 
- - This re-implements CreateProcess's quoting logic to turn a list of - arguments into one single string for the shell to interpret. - - * All double quotes are escaped with a backslash. - * Existing backslashes before a quote are doubled, so they are all - escaped properly. - * Backslashes elsewhere are left as-is; cmd will interpret them - literally. - - The result is then quoted into a pair of double quotes to be grouped. - - An argument is intentionally not quoted if it does not contain - whitespaces. This is done to be compatible with Windows built-in - commands that don't work well with quotes, e.g. everything with `echo`, - and DOS-style (forward slash) switches. - - The intended use of this function is to pre-process an argument list - before passing it into ``subprocess.Popen(..., shell=True)``. - - See also: https://docs.python.org/3/library/subprocess.html - """ - parts = list(self._parts) - if extra_args: - parts.extend(extra_args) - return " ".join( - arg if not next(re.finditer(r'\s', arg), None) - else '"{0}"'.format(re.sub(r'(\\*)"', r'\1\1\\"', arg)) - for arg in parts - ) - - -@dataclass -class PackageCollection(BaseModel): - - packages: List[Package] - - def validate_packages(self, value): - if isinstance(value, dict): - packages = {} - for k, v in value.items(): - if isinstance(v, dict): - packages[k] = Package(**v) - else: - packages[k] = Package(version=v) - return packages - return value - - -@dataclass -class ScriptCollection(BaseModel): - scripts: List[Script] - - -@dataclass -class SourceCollection(BaseModel): - - sources: List[Source] - - def validate_sources(self, value): - sources = [] - for v in value: - if isinstance(v, dict): - sources.append(Source(**v)) - elif isinstance(v, Source): - sources.append(v) - return sources - - def __iter__(self): - return (d for d in self.sources) - - def __getitem__(self, key): - if isinstance(key, slice): - return SourceCollection(self.sources[key]) - if isinstance(key, int): - src = 
self.sources[key] - if isinstance(src, dict): - return Source(**key) - if isinstance(src, Source): - return src - raise TypeError(f"Unextepcted type {type(src)}") - - def __len__(self): - return len(self.sources) - - def __setitem__(self, key, value): - if isinstance(key, slice): - self.sources[key] = value - elif isinstance(value, Source): - self.sources.append(value) - elif isinstance(value, list): - self.sources.extend(value) - else: - raise TypeError(f"Unextepcted type {type(value)} for {value}") - - def __delitem__(self, key): - del self.sources[key] - - -@dataclass -class Requires(BaseModel): - python_version: Optional[str] = None - python_full_version: Optional[str] = None - - -META_SECTIONS = { - "hash": Hash, - "requires": Requires, - "sources": SourceCollection, -} - - -@dataclass -class PipfileSection(BaseModel): - - """ - Dummy pipfile validator that needs to be completed in a future PR - Hint: many pipfile features are undocumented in pipenv/project.py - """ - - -@dataclass -class Meta(BaseModel): - - hash: Hash - pipfile_spec: str - requires: Requires - sources: SourceCollection - - @classmethod - def from_dict(cls, d: dict) -> "Meta": - return cls(**{k.replace('-', '_'): v for k, v in d.items()}) - - def validate_hash(self, value): - try: - return Hash(**value) - except TypeError: - return Hash.from_line(value) - - def validate_requires(self, value): - return Requires(value) - - def validate_sources(self, value): - return SourceCollection(value) - - def validate_pipfile_spec(self, value): - if int(value) != 6: - raise ValueError('Only pipefile-spec version 6 is supported') - return value - - -@dataclass -class Pipenv(BaseModel): - """Represent the [pipenv] section in Pipfile""" - allow_prereleases: Optional[bool] = False - install_search_all_sources: Optional[bool] = True - - def validate_allow_prereleases(self, value): - if not isinstance(value, bool): - raise ValidationError('allow_prereleases must be a boolean') - return value - - def 
validate_install_search_all_sources(self, value): - if not isinstance(value, bool): - raise ValidationError('install_search_all_sources must be a boolean') - - return value diff --git a/src/plette/models/__init__.py b/src/plette/models/__init__.py new file mode 100644 index 0000000..babe1f9 --- /dev/null +++ b/src/plette/models/__init__.py @@ -0,0 +1,26 @@ +__all__ = [ + "DataModel", "DataModelCollection", "DataModelMapping", "DataModelSequence", + "DataValidationError", + "Hash", "Package", "Requires", "Source", "Script", + "Meta", "PackageCollection", "Pipenv", "PipfileSection", "ScriptCollection", "SourceCollection", +] + +from .base import ( + DataModel, DataModelCollection, DataModelMapping, DataModelSequence, + DataValidationError, +) + +from .hashes import Hash +from .packages import Package +from .scripts import Script +from .sources import Source + +from .sections import ( + Meta, + Requires, + PackageCollection, + Pipenv, + PipfileSection, + ScriptCollection, + SourceCollection, +) diff --git a/src/plette/models/base.py b/src/plette/models/base.py new file mode 100644 index 0000000..c3d6937 --- /dev/null +++ b/src/plette/models/base.py @@ -0,0 +1,130 @@ +class DataValidationError(ValueError): + pass + + +class DataModel: + + def __init__(self, data): + self.validate(data) + self._data = data + + def __repr__(self): + return "{0}({1!r})".format(type(self).__name__, self._data) + + def __eq__(self, other): + if not isinstance(other, type(self)): + raise TypeError( + "cannot compare {0!r} with {1!r}".format( + type(self).__name__, type(other).__name__ + ) + ) + return self._data == other._data + + def __getitem__(self, key): + return self._data[key] + + def __setitem__(self, key, value): + self._data[key] = value + + def __delitem__(self, key): + del self._data[key] + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + @classmethod + def validate(cls, data): + for k, v in cls.__SCHEMA__.items(): + if k not in data: + raise 
DataValidationError(f"Missing required field: {k}") + if not isinstance(data[k], v): + raise DataValidationError(f"Invalid type for field {k}: {type(data[k])}") + + if hasattr(cls, "__OPTIONAL__"): + for k, v in cls.__OPTIONAL__.items(): + if k in data and not isinstance(data[k], v): + raise DataValidationError(f"Invalid type for field {k}: {type(data[k])}") + + +class DataModelCollection(DataModel): + """A homogeneous collection of data views. + + Subclasses are expected to assign a class attribute `item_class` to specify + the type of items it contains. This class will be used to coerce return + values when accessed. The item class should conform to the `DataModel` + protocol. + + You should not instantiate an instance from this class, but from one of its + subclasses instead. + """ + + item_class = None + + def __repr__(self): + return "{0}({1!r})".format(type(self).__name__, self._data) + + def __len__(self): + return len(self._data) + + def __getitem__(self, key): + return self.item_class(self._data[key]) + + def __setitem__(self, key, value): + if isinstance(value, DataModel): + value = value._data + self._data[key] = value + + def __delitem__(self, key): + del self._data[key] + + +class DataModelSequence(DataModelCollection): + """A sequence of data views. + + Each entry is an instance of `item_class`. + """ + + @classmethod + def validate(cls, data): + for d in data: + cls.item_class.validate(d) + + def __iter__(self): + return (self.item_class(d) for d in self._data) + + def __getitem__(self, key): + if isinstance(key, slice): + return type(self)(self._data[key]) + return super().__getitem__(key) + + def append(self, value): + if isinstance(value, DataModel): + value = value._data + self._data.append(value) + + +class DataModelMapping(DataModelCollection): + """A mapping of data views. + + The keys are primitive values, while values are instances of `item_class`. 
+ """ + + @classmethod + def validate(cls, data): + for d in data.values(): + cls.item_class.validate(d) + + def __iter__(self): + return iter(self._data) + + def keys(self): + return self._data.keys() + + def values(self): + return [self[k] for k in self._data] + + def items(self): + return [(k, self[k]) for k in self._data] diff --git a/src/plette/models/hashes.py b/src/plette/models/hashes.py new file mode 100644 index 0000000..75c4c2c --- /dev/null +++ b/src/plette/models/hashes.py @@ -0,0 +1,69 @@ +from .base import DataModel, DataValidationError + + +class Hash(DataModel): + """A hash. + """ + item_class = "Hash" + + __SCHEMA__ = { + } + + __OPTIONAL__ = { + "name": str, + "md5": str, + "sha256": str, + "digest": str, + } + + def __init__(self, data): + self.validate(data) + self._data = data + if "name" in data: + self.name = data["name"] + try: + self.digest = data["digest"] + except KeyError: + self.digest = data["value"] + elif "md5" in data: + self.name = "md5" + self.digest = data["md5"] + elif "sha256" in data: + self.name = "sha256" + self.digest = data["sha256"] + + @classmethod + def validate(cls, data): + for k, v in cls.__SCHEMA__.items(): + if k not in data: + raise DataValidationError(f"Missing required field: {k}") + if not isinstance(data[k], v): + raise DataValidationError(f"Invalid type for field {k}: {type(data[k])}") + + @classmethod + def from_hash(cls, ins): + """Interpolation to the hash result of `hashlib`. 
+ """ + return cls(data={ins.name: ins.hexdigest()}) + + @classmethod + def from_line(cls, value): + try: + name, value = value.split(":", 1) + except ValueError: + name = "sha256" + return cls(data={"name":name, "value": value}) + + def __eq__(self, other): + if not isinstance(other, Hash): + raise TypeError("cannot compare Hash with {0!r}".format( + type(other).__name__, + )) + return self._data == other._data + + @property + def value(self): + return self.digest + + def as_line(self): + return "{0[0]}:{0[1]}".format(next(iter(self._data.items()))) diff --git a/src/plette/models/packages.py b/src/plette/models/packages.py new file mode 100644 index 0000000..596a399 --- /dev/null +++ b/src/plette/models/packages.py @@ -0,0 +1,56 @@ +import tomlkit + +from .base import DataModel, DataValidationError + +class PackageSpecfiers(DataModel): + # TODO: one could add here more validation for path editable + # and more stuff which is currently allowed and undocumented + __SCHEMA__ = {} + __OPTIONAL__ = { + "editable": bool, + "version": str, + "extras": list + } + + +class Package(DataModel): + """A package requirement specified in a Pipfile. + + This is the base class of variants appearing in either `[packages]` or + `[dev-packages]` sections of a Pipfile. + """ + # The extra layer is intentional. Cerberus does not allow top-level keys + # to have oneof_schema (at least I can't do it), so we wrap this in a + # top-level key. The Requirement model class implements extra hacks to + # make this work. 
+ __OPTIONAL__ = { + "PackageSpecfiers": (str, dict) + } + + @classmethod + def validate(cls, data): + if isinstance(data, (str, tomlkit.items.Float, tomlkit.items.Integer)): + return + if isinstance(data, dict): + PackageSpecfiers.validate(data) + else: + raise DataValidationError(f"invalid type for package data: {type(data)}") + + def __getattr__(self, key): + if isinstance(self._data, (str, tomlkit.items.Float, tomlkit.items.Integer)): + if key == "version": + return self._data + raise AttributeError(key) + try: + return self._data[key] + except KeyError: + pass + raise AttributeError(key) + + def __setattr__(self, key, value): + if key == "_data": + super().__setattr__(key, value) + elif key == "version" and isinstance(self._data, str): + self._data = value + else: + self._data[key] = value diff --git a/src/plette/models/scripts.py b/src/plette/models/scripts.py new file mode 100644 index 0000000..9a77b44 --- /dev/null +++ b/src/plette/models/scripts.py @@ -0,0 +1,71 @@ +import re +import shlex + +from .base import DataModel, DataValidationError + + +class Script(DataModel): + """Parse a script line (in Pipfile's [scripts] section). + + This always works in POSIX mode, even on Windows. + """ + __OPTIONAL__ = { + "script": (str,list) + } + + def __init__(self, data): + self.validate(data) + if isinstance(data, str): + data = shlex.split(data) + self._parts = data[::] + + @classmethod + def validate(cls, data): + if not data: + raise DataValidationError("Script cannot be empty") + for k, types in cls.__OPTIONAL__.items(): + if not isinstance(data, types): + raise DataValidationError(f"Invalid type for field {k}: {type(data)}") + def __repr__(self): + return "Script({0!r})".format(self._parts) + + @property + def command(self): + return self._parts[0] + + @property + def args(self): + return self._parts[1:] + + def cmdify(self, extra_args=None): + """Encode into a cmd-executable string. 
+ + This re-implements CreateProcess's quoting logic to turn a list of + arguments into one single string for the shell to interpret. + + * All double quotes are escaped with a backslash. + * Existing backslashes before a quote are doubled, so they are all + escaped properly. + * Backslashes elsewhere are left as-is; cmd will interpret them + literally. + + The result is then quoted into a pair of double quotes to be grouped. + + An argument is intentionally not quoted if it does not contain + whitespaces. This is done to be compatible with Windows built-in + commands that don't work well with quotes, e.g. everything with `echo`, + and DOS-style (forward slash) switches. + + The intended use of this function is to pre-process an argument list + before passing it into ``subprocess.Popen(..., shell=True)``. + + See also: https://docs.python.org/3/library/subprocess.html + """ + parts = list(self._parts) + if extra_args: + parts.extend(extra_args) + return " ".join( + arg if not next(re.finditer(r'\s', arg), None) + else '"{0}"'.format(re.sub(r'(\\*)"', r'\1\1\\"', arg)) + for arg in parts + ) diff --git a/src/plette/models/sections.py b/src/plette/models/sections.py new file mode 100644 index 0000000..82d893b --- /dev/null +++ b/src/plette/models/sections.py @@ -0,0 +1,138 @@ +from .base import DataModel, DataModelSequence, DataModelMapping +from .hashes import Hash +from .packages import Package +from .scripts import Script +from .sources import Source + + +class PackageCollection(DataModelMapping): + item_class = Package + + +class ScriptCollection(DataModelMapping): + item_class = Script + + +class SourceCollection(DataModelSequence): + item_class = Source + + +class Requires(DataModel): + """Representation of the `[requires]` section in a Pipfile.""" + + __SCHEMA__ = {} + + __OPTIONAL__ = { + "python_version": str, + "python_full_version": str, + } + + @property + def python_version(self): + try: + return self._data["python_version"] + except KeyError: + raise 
AttributeError("python_version") + + @property + def python_full_version(self): + try: + return self._data["python_full_version"] + except KeyError: + raise AttributeError("python_full_version") + + +META_SECTIONS = { + "hash": Hash, + "requires": Requires, + "sources": SourceCollection, +} + + +class PipfileSection(DataModel): + + """ + Dummy pipfile validator that needs to be completed in a future PR + Hint: many pipfile features are undocumented in pipenv/project.py + """ + + @classmethod + def validate(cls, data): + pass + + +class Meta(DataModel): + """Representation of the `_meta` section in a Pipfile.lock.""" + + __SCHEMA__ = { + "hash": "dict", + "pipfile-spec": "integer", + "requires": "dict", + "sources": "list" + } + + @classmethod + def validate(cls, data): + for key, klass in META_SECTIONS.items(): + klass.validate(data[key]) + + def __getitem__(self, key): + value = super().__getitem__(key) + try: + return META_SECTIONS[key](value) + except KeyError: + return value + + def __setitem__(self, key, value): + if isinstance(value, DataModel): + self._data[key] = value._data + else: + self._data[key] = value + + @property + def hash_(self): + return self["hash"] + + @hash_.setter + def hash_(self, value): + self["hash"] = value + + @property + def hash(self): + return self["hash"] + + @hash.setter + def hash(self, value): + self["hash"] = value + + @property + def pipfile_spec(self): + return self["pipfile-spec"] + + @pipfile_spec.setter + def pipfile_spec(self, value): + self["pipfile-spec"] = value + + @property + def requires(self): + return self["requires"] + + @requires.setter + def requires(self, value): + self["requires"] = value + + @property + def sources(self): + return self["sources"] + + @sources.setter + def sources(self, value): + self["sources"] = value + + +class Pipenv(DataModel): + """Represent the [pipenv] section in Pipfile""" + __SCHEMA__ = {} + __OPTIONAL__ = { + "allow_prereleases": bool, + } diff --git a/src/plette/models/sources.py 
b/src/plette/models/sources.py new file mode 100644 index 0000000..95fc56a --- /dev/null +++ b/src/plette/models/sources.py @@ -0,0 +1,45 @@ +import os + +from .base import DataModel + + +class Source(DataModel): + """Information on a "simple" Python package index. + + This could be PyPI, or a self-hosted index server, etc. The server + specified by the `url` attribute is expected to provide the "simple" + package API. + """ + __SCHEMA__ = { + "name": str, + "url": str, + "verify_ssl": bool, + } + + @property + def name(self): + return self._data["name"] + + @name.setter + def name(self, value): + self._data["name"] = value + + @property + def url(self): + return self._data["url"] + + @url.setter + def url(self, value): + self._data["url"] = value + + @property + def verify_ssl(self): + return self._data["verify_ssl"] + + @verify_ssl.setter + def verify_ssl(self, value): + self._data["verify_ssl"] = value + + @property + def url_expanded(self): + return os.path.expandvars(self._data["url"]) diff --git a/src/plette/pipfiles.py b/src/plette/pipfiles.py index 914741d..aaf8da7 100644 --- a/src/plette/pipfiles.py +++ b/src/plette/pipfiles.py @@ -1,23 +1,16 @@ import hashlib import json -from dataclasses import dataclass, asdict - -from typing import Optional - import tomlkit - from .models import ( - BaseModel, - Hash, Requires, PipfileSection, Pipenv, + DataModel, Hash, Requires, PipfileSection, Pipenv, PackageCollection, ScriptCollection, SourceCollection, - remove_empty_values ) PIPFILE_SECTIONS = { - "sources": SourceCollection, + "source": SourceCollection, "packages": PackageCollection, "dev-packages": PackageCollection, "requires": Requires, @@ -33,56 +26,27 @@ verify_ssl = true """ +class Pipfile(DataModel): + """Representation of a Pipfile. 
+ """ + __SCHEMA__ = {} -@dataclass -class Pipfile(BaseModel): - """Representation of a Pipfile.""" - sources: SourceCollection - packages: Optional[PackageCollection] = None - packages: Optional[PackageCollection] = None - dev_packages: Optional[PackageCollection] = None - requires: Optional[Requires] = None - scripts: Optional[ScriptCollection] = None - pipfile: Optional[PipfileSection] = None - pipenv: Optional[Pipenv] = None - - def validate_sources(self, value): - if isinstance(value, list): - return SourceCollection(value) - return SourceCollection(value.value) - - def validate_pipenv(self, value): - if value is not None: - return Pipenv(**value) - return value - - def validate_packages(self, value): - PackageCollection(value) - return value - - def to_dict(self): - data = { - "_meta": { - "requires": getattr(self, "requires", {}), - }, - "default": getattr(self, "packages", {}), - "develop": getattr(self, "dev-packages", {}), - } - data["_meta"].update(asdict(getattr(self, "sources", {}))) - for category, values in self.__dict__.items(): - if category in PIPFILE_SECTIONS or category in ( - "default", "develop", "pipenv"): + @classmethod + def validate(cls, data): + # HACK: DO NOT CALL `super().validate()` here!! + # Cerberus seems to break TOML Kit's inline table preservation if it + # is not at the top-level. Fortunately the spec doesn't have nested + # non-inlined tables, so we're OK as long as validation is only + # performed at section-level. validation is performed. 
+ for key, klass in PIPFILE_SECTIONS.items(): + if key not in data: continue - data[category] = values - remove_empty_values(data) - return data + klass.validate(data[key]) - def get_hash(self): - data = self.to_dict() - content = json.dumps(data, sort_keys=True, separators=(",", ":")) - if isinstance(content, str): - content = content.encode("utf-8") - return Hash.from_hash(hashlib.sha256(content)) + package_categories = set(data.keys()) - set(PIPFILE_SECTIONS.keys()) + + for category in package_categories: + PackageCollection.validate(data[category]) @classmethod def load(cls, f, encoding=None): @@ -98,32 +62,107 @@ def load(cls, f, encoding=None): sep = "" if content.startswith("\n") else "\n" content = DEFAULT_SOURCE_TOML + sep + content data = tomlkit.loads(content) - data["sources"] = data.pop("source") - packages_sections = {} - data_sections = list(data.keys()) - for k in data_sections: - if k not in cls.__dataclass_fields__: - packages_sections[k] = data.pop(k) - - inst = cls(**data) - if packages_sections: - for k, v in packages_sections.items(): - setattr(inst, k, PackageCollection(v)) - return inst + return cls(data) - @property - def source(self): - return self.sources + def __getitem__(self, key): + value = self._data[key] + try: + return PIPFILE_SECTIONS[key](value) + except KeyError: + return value - def dump(self, f, encoding=None): - data = self.to_dict() - new_data = {} - metadata = data.pop("_meta") - new_data["source"] = metadata.pop("sources") - new_data["packages"] = data.pop("default") - new_data.update(data) - content = tomlkit.dumps(new_data) + def __setitem__(self, key, value): + if isinstance(value, DataView): + self._data[key] = value._data + else: + self._data[key] = value + def get_hash(self): + data = { + "_meta": { + "sources": self._data["source"], + "requires": self._data.get("requires", {}), + }, + "default": self._data.get("packages", {}), + "develop": self._data.get("dev-packages", {}), + } + for category, values in 
self._data.items(): + if category in PIPFILE_SECTIONS or category in ("default", "develop", "pipenv"): + continue + data[category] = values + content = json.dumps(data, sort_keys=True, separators=(",", ":")) + if isinstance(content, str): + content = content.encode("utf-8") + return Hash.from_hash(hashlib.sha256(content)) + + def dump(self, f, encoding=None): + content = tomlkit.dumps(self._data) if encoding is not None: content = content.encode(encoding) f.write(content) + + @property + def sources(self): + try: + return self["source"] + except KeyError: + raise AttributeError("sources") + + @sources.setter + def sources(self, value): + self["source"] = value + + @property + def source(self): + try: + return self["source"] + except KeyError: + raise AttributeError("source") + + @source.setter + def source(self, value): + self["source"] = value + + @property + def packages(self): + try: + return self["packages"] + except KeyError: + raise AttributeError("packages") + + @packages.setter + def packages(self, value): + self["packages"] = value + + @property + def dev_packages(self): + try: + return self["dev-packages"] + except KeyError: + raise AttributeError("dev-packages") + + @dev_packages.setter + def dev_packages(self, value): + self["dev-packages"] = value + + @property + def requires(self): + try: + return self["requires"] + except KeyError: + raise AttributeError("requires") + + @requires.setter + def requires(self, value): + self["requires"] = value + + @property + def scripts(self): + try: + return self["scripts"] + except KeyError: + raise AttributeError("scripts") + + @scripts.setter + def scripts(self, value): + self["scripts"] = value diff --git a/tests/integration/test_examples.py b/tests/integration/test_examples.py index 090ef11..61338d2 100644 --- a/tests/integration/test_examples.py +++ b/tests/integration/test_examples.py @@ -5,13 +5,13 @@ import plette from plette import Pipfile +from plette.models.base import DataValidationError invalid_files = 
glob.glob("examples/*invalid*") valid_files = glob.glob("examples/*ok*") @pytest.mark.parametrize("fname", invalid_files) def test_invalid_files(fname): - - with pytest.raises(plette.models.ValidationError): + with pytest.raises((ValueError, DataValidationError)) as excinfo: with open(fname) as f: pipfile = Pipfile.load(f) diff --git a/tests/test_lockfiles.py b/tests/test_lockfiles.py index 6d03bf2..be41269 100644 --- a/tests/test_lockfiles.py +++ b/tests/test_lockfiles.py @@ -1,17 +1,17 @@ -# pylint: disable=missing-module-docstring,missing-class-docstring -# pylint: disable=missing-function-docstring -# pylint: disable=no-member -import json +from __future__ import unicode_literals + import textwrap +import pytest + from plette import Lockfile, Pipfile -from plette.models import Package, SourceCollection, Hash, Requires +from plette.models import Package, SourceCollection HASH = "9aaf3dbaf8c4df3accd4606eb2275d3b91c9db41be4fd5a97ecc95d79a12cfe6" -def test_lockfile_load_sources(tmpdir): +def test_lockfile_load(tmpdir): fi = tmpdir.join("in.json") fi.write(textwrap.dedent( """\ @@ -44,34 +44,6 @@ def test_lockfile_load_sources(tmpdir): 'name': 'pypi', }, ]) - - -def test_lockfile_load_sources_package_spec(tmpdir): - fi = tmpdir.join("in.json") - fi.write(textwrap.dedent( - """\ - { - "_meta": { - "hash": {"sha256": "____hash____"}, - "pipfile-spec": 6, - "requires": {}, - "sources": [ - { - "name": "pypi", - "url": "https://pypi.org/simple", - "verify_ssl": true - } - ] - }, - "default": { - "flask": {"version": "*"}, - "jinja2": "*" - }, - "develop": {} - } - """, - ).replace("____hash____", HASH)) - lock = Lockfile.load(fi) assert lock.default["jinja2"] == Package("*") @@ -112,14 +84,13 @@ def test_lockfile_dump_format(tmpdir): outpath = tmpdir.join("out.json") with outpath.open("w") as f: lock.dump(f) - loaded = json.loads(outpath.read()) - assert "_meta" in loaded - assert json.loads(outpath.read()) == json.loads(content) + + assert outpath.read() == content def 
test_lockfile_from_pipfile_meta(): - pipfile = Pipfile(**{ - "sources": [ + pipfile = Pipfile({ + "source": [ { "name": "pypi", "url": "https://pypi.org/simple", @@ -130,23 +101,22 @@ def test_lockfile_from_pipfile_meta(): "python_version": "3.7", } }) - pipfile_hash_value = pipfile.get_hash().value lockfile = Lockfile.with_meta_from(pipfile) - pipfile.requires["python_version"] = "3.8" - pipfile.sources.sources.append({ + pipfile.requires._data["python_version"] = "3.8" + pipfile.sources.append({ "name": "devpi", "url": "http://localhost/simple", "verify_ssl": True, }) - assert lockfile.meta.hash == Hash.from_dict({"sha256": pipfile_hash_value}) - assert lockfile.meta.requires == Requires(python_version={'python_version': '3.7'}, python_full_version=None) - assert lockfile.meta.sources == SourceCollection([ + assert lockfile.meta.hash._data == {"sha256": pipfile_hash_value} + assert lockfile.meta.requires._data == {"python_version": "3.7"} + assert lockfile.meta.sources._data == [ { "name": "pypi", "url": "https://pypi.org/simple", "verify_ssl": True, }, - ]) + ] diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..cc6ed49 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,243 @@ +import hashlib + +import tomlkit.items +import pytest + +from plette import models + + +def test_hash_from_hash(): + v = hashlib.md5(b"foo") + h = models.Hash.from_hash(v) + assert h.name == "md5" + assert h.value == "acbd18db4cc2f85cedef654fccc4a4d8" + + +def test_hash_from_line(): + h = models.Hash.from_line("md5:acbd18db4cc2f85cedef654fccc4a4d8") + assert h.name == "md5" + assert h.value == "acbd18db4cc2f85cedef654fccc4a4d8" + + +def test_hash_as_line(): + h = models.Hash({"md5": "acbd18db4cc2f85cedef654fccc4a4d8"}) + assert h.as_line() == "md5:acbd18db4cc2f85cedef654fccc4a4d8" + + +def test_source_from_data(): + s = models.Source( + { + "name": "devpi", + "url": "https://$USER:$PASS@mydevpi.localhost", + "verify_ssl": False, + } + ) + assert 
s.name == "devpi" + assert s.url == "https://$USER:$PASS@mydevpi.localhost" + assert s.verify_ssl is False + + +def test_source_as_data_expanded(monkeypatch): + monkeypatch.setattr("os.environ", {"USER": "user", "PASS": "pa55"}) + s = models.Source( + { + "name": "devpi", + "url": "https://$USER:$PASS@mydevpi.localhost", + "verify_ssl": False, + } + ) + assert s.url_expanded == "https://user:pa55@mydevpi.localhost" + + +def test_source_as_data_expanded_partial(monkeypatch): + monkeypatch.setattr("os.environ", {"USER": "user"}) + s = models.Source( + { + "name": "devpi", + "url": "https://$USER:$PASS@mydevpi.localhost", + "verify_ssl": False, + } + ) + assert s.url_expanded == "https://user:$PASS@mydevpi.localhost" + + +def test_requires_python_version(): + r = models.Requires({"python_version": "8.19"}) + assert r.python_version == "8.19" + + +def test_requires_python_version_no_full_version(): + r = models.Requires({"python_version": "8.19"}) + with pytest.raises(AttributeError) as ctx: + r.python_full_version + assert str(ctx.value) == "python_full_version" + + +def test_requires_python_full_version(): + r = models.Requires({"python_full_version": "8.19"}) + assert r.python_full_version == "8.19" + + +def test_requires_python_full_version_no_version(): + r = models.Requires({"python_full_version": "8.19"}) + with pytest.raises(AttributeError) as ctx: + r.python_version + assert str(ctx.value) == "python_version" + + +def test_allows_python_version_and_full(): + r = models.Requires({"python_version": "8.1", "python_full_version": "8.1.9"}) + assert r.python_version == "8.1" + assert r.python_full_version == "8.1.9" + + +def test_package_str(): + p = models.Package("*") + p.version == "*" + + +def test_package_dict(): + p = models.Package({"version": "*"}) + p.version == "*" + + +def test_package_wrong_key(): + p = models.Package({"path": ".", "editable": True}) + assert p.editable is True + with pytest.raises(AttributeError) as ctx: + p.version + assert 
str(ctx.value) == "version" + + +def test_package_with_wrong_specfiers(): + with pytest.raises(models.base.DataValidationError) as ctx: + _ = models.Package(1.2) + assert str(ctx.value) == "invalid type for package data: " + + +def test_package_with_specfiers(): + value = 1.2 + float_value = tomlkit.items.Float(value, tomlkit.items.Trivia(), str(value)) + p = models.Package(float_value) + assert p.version == float_value + + +def test_package_with_wrong_extras(): + with pytest.raises(models.base.DataValidationError): + _ = models.Package({"version": "==1.20.0", "extras": "broker"}) + + +def test_package_with_extras(): + p = models.Package({"version": "==1.20.0", "extras": ["broker", "tests"]}) + assert p.extras == ['broker', 'tests'] + + +HASH = "9aaf3dbaf8c4df3accd4606eb2275d3b91c9db41be4fd5a97ecc95d79a12cfe6" + + +def test_meta(): + m = models.Meta( + { + "hash": {"sha256": HASH}, + "pipfile-spec": 6, + "requires": {}, + "sources": [ + { + "name": "pypi", + "url": "https://pypi.org/simple", + "verify_ssl": True, + }, + ], + } + ) + assert m.hash.name == "sha256" + + +@pytest.fixture() +def sources(): + return models.SourceCollection( + [ + { + "name": "pypi", + "url": "https://pypi.org/simple", + "verify_ssl": True, + }, + { + "name": "devpi", + "url": "http://127.0.0.1:$DEVPI_PORT/simple", + "verify_ssl": True, + }, + ] + ) + + +def test_get_slice(sources): + sliced = sources[:1] + assert isinstance(sliced, models.SourceCollection) + assert len(sliced) == 1 + assert sliced[0] == models.Source( + { + "name": "pypi", + "url": "https://pypi.org/simple", + "verify_ssl": True, + } + ) + + +def test_set_slice(sources): + sources[1:] = [ + { + "name": "localpi-4433", + "url": "https://127.0.0.1:4433/simple", + "verify_ssl": False, + }, + { + "name": "localpi-8000", + "url": "http://127.0.0.1:8000/simple", + "verify_ssl": True, + }, + ] + assert sources._data == [ + { + "name": "pypi", + "url": "https://pypi.org/simple", + "verify_ssl": True, + }, + { + "name": 
"localpi-4433", + "url": "https://127.0.0.1:4433/simple", + "verify_ssl": False, + }, + { + "name": "localpi-8000", + "url": "http://127.0.0.1:8000/simple", + "verify_ssl": True, + }, + ] + + +def test_del_slice(sources): + del sources[:1] + assert sources._data == [ + { + "name": "devpi", + "url": "http://127.0.0.1:$DEVPI_PORT/simple", + "verify_ssl": True, + }, + ] + + +def test_validation_error(): + data = {"name": "test", "url": "https://pypi.org/simple", "verify_ssl": 1} + with pytest.raises(models.base.DataValidationError) as exc_info: + models.Source.validate(data) + + error_message = str(exc_info.value) + assert "Invalid type for field verify_ssl: " in error_message + + data = {"name": "test", "verify_ssl": False} + with pytest.raises(models.base.DataValidationError) as exc_info: + models.Source.validate(data) + + error_message = str(exc_info.value) + assert "Missing required field: url" in error_message diff --git a/tests/test_models_hash.py b/tests/test_models_hash.py deleted file mode 100644 index 99b4f8b..0000000 --- a/tests/test_models_hash.py +++ /dev/null @@ -1,20 +0,0 @@ -import hashlib - -from plette.models import Hash - -def test_hash_from_hash(): - v = hashlib.md5(b"foo") - h = Hash.from_hash(v) - assert h.name == "md5" - assert h.value == "acbd18db4cc2f85cedef654fccc4a4d8" - - -def test_hash_from_line(): - h = Hash.from_line("md5:acbd18db4cc2f85cedef654fccc4a4d8") - assert h.name == "md5" - assert h.value == "acbd18db4cc2f85cedef654fccc4a4d8" - - -def test_hash_as_line(): - h = Hash.from_dict({"md5": "acbd18db4cc2f85cedef654fccc4a4d8"}) - assert h.as_line() == "md5:acbd18db4cc2f85cedef654fccc4a4d8" diff --git a/tests/test_models_meta.py b/tests/test_models_meta.py deleted file mode 100644 index 20f9631..0000000 --- a/tests/test_models_meta.py +++ /dev/null @@ -1,20 +0,0 @@ -from plette.models import Meta - -HASH = "9aaf3dbaf8c4df3accd4606eb2275d3b91c9db41be4fd5a97ecc95d79a12cfe6" - -def test_meta(): - m = Meta.from_dict( - { - "hash": {"sha256": 
HASH}, - "pipfile-spec": 6, - "requires": {}, - "sources": [ - { - "name": "pypi", - "url": "https://pypi.org/simple", - "verify_ssl": True, - }, - ], - } - ) - assert m.hash.name == "sha256" diff --git a/tests/test_models_packages.py b/tests/test_models_packages.py deleted file mode 100644 index 91b973d..0000000 --- a/tests/test_models_packages.py +++ /dev/null @@ -1,39 +0,0 @@ -import pytest - -from plette.models import Package - -def test_package_str(): - p = Package("*") - assert p.version == "*" - - -def test_package_dict(): - p = Package({"version": "*"}) - assert p.version == {"version": "*"} - - -def test_package_version_is_none(): - p = Package(**{"path": ".", "editable": True}) - assert p.version == "*" - assert p.editable is True - -def test_package_with_wrong_extras(): - with pytest.raises(ValueError): - p = Package(**{"version": "==1.20.0", "extras": "broker"}) - - with pytest.raises(ValueError): - p = Package(**{"version": "==1.20.0", "extras": ["broker", {}]}) - - with pytest.raises(ValueError): - p = Package(**{"version": "==1.20.0", "extras": ["broker", 1]}) - - -def test_package_with_extras(): - p = Package(**{"version": "==1.20.0", "extras": ["broker", "tests"]}) - assert p.extras == ['broker', 'tests'] - - -def test_package_wrong_key(): - p = Package(**{"path": ".", "editable": True}) - assert p.editable is True - assert p.version is "*" diff --git a/tests/test_models_requires.py b/tests/test_models_requires.py deleted file mode 100644 index 814b7a4..0000000 --- a/tests/test_models_requires.py +++ /dev/null @@ -1,27 +0,0 @@ -import pytest -from plette import models - -def test_requires_python_version(): - r = models.Requires(**{"python_version": "8.19"}) - assert r.python_version == "8.19" - - -def test_requires_python_version_no_full_version(): - r = models.Requires(**{"python_version": "8.19"}) - r.python_full_version is None - - -def test_requires_python_full_version(): - r = models.Requires(**{"python_full_version": "8.19"}) - assert 
r.python_full_version == "8.19" - - -def test_requires_python_full_version_no_version(): - r = models.Requires(**{"python_full_version": "8.19"}) - r.python_version is None - - -def test_allows_python_version_and_full(): - r = models.Requires(**{"python_version": "8.1", "python_full_version": "8.1.9"}) - assert r.python_version == "8.1" - assert r.python_full_version == "8.1.9" diff --git a/tests/test_models_sourcecollections.py b/tests/test_models_sourcecollections.py deleted file mode 100644 index 54165c9..0000000 --- a/tests/test_models_sourcecollections.py +++ /dev/null @@ -1,81 +0,0 @@ -import hashlib - -import pytest - -from plette import models -from plette.models import Source, SourceCollection - - -@pytest.fixture() -def sources(): - return models.SourceCollection( - [ - { - "name": "pypi", - "url": "https://pypi.org/simple", - "verify_ssl": True, - }, - { - "name": "devpi", - "url": "http://127.0.0.1:$DEVPI_PORT/simple", - "verify_ssl": True, - }, - ] - ) - - -def test_get_slice(sources): - sliced = sources[:1] - assert isinstance(sliced, models.SourceCollection) - assert len(sliced) == 1 - assert sliced[0] == models.Source( - **{ - "name": "pypi", - "url": "https://pypi.org/simple", - "verify_ssl": True, - } - ) - - -def test_set_slice(sources): - sources[1:] = [ - Source(**{ - "name": "localpi-4433", - "url": "https://127.0.0.1:4433/simple", - "verify_ssl": False, - }), - Source(**{ - "name": "localpi-8000", - "url": "http://127.0.0.1:8000/simple", - "verify_ssl": True, - }), - ] - assert sources == \ - SourceCollection([ - Source(**{ - "name": "pypi", - "url": "https://pypi.org/simple", - "verify_ssl": True, - }), - Source(**{ - "name": "localpi-4433", - "url": "https://127.0.0.1:4433/simple", - "verify_ssl": False, - }), - Source(**{ - "name": "localpi-8000", - "url": "http://127.0.0.1:8000/simple", - "verify_ssl": True, - }), - ]) - - -def test_del_slice(sources): - del sources[:1] - assert sources == SourceCollection([ - Source(**{ - "name": 
"devpi", - "url": "http://127.0.0.1:$DEVPI_PORT/simple", - "verify_ssl": True, - }), - ]) diff --git a/tests/test_models_sources.py b/tests/test_models_sources.py deleted file mode 100644 index f0ae440..0000000 --- a/tests/test_models_sources.py +++ /dev/null @@ -1,60 +0,0 @@ -import pytest -from plette.models import Source -from plette import models - -def test_source_from_data(): - s = Source( - **{ - "name": "devpi", - "url": "https://$USER:$PASS@mydevpi.localhost", - "verify_ssl": False, - } - ) - assert s.name == "devpi" - assert s.url == "https://$USER:$PASS@mydevpi.localhost" - assert s.verify_ssl is False - - -def test_source_as_data_expanded(monkeypatch): - monkeypatch.setattr("os.environ", {"USER": "user", "PASS": "pa55"}) - s = Source( - **{ - "name": "devpi", - "url": "https://$USER:$PASS@mydevpi.localhost", - "verify_ssl": False, - } - ) - assert s.url_expanded == "https://user:pa55@mydevpi.localhost" - - -def test_source_as_data_expanded_partial(monkeypatch): - monkeypatch.setattr("os.environ", {"USER": "user"}) - s = Source( - **{ - "name": "devpi", - "url": "https://$USER:$PASS@mydevpi.localhost", - "verify_ssl": False, - } - ) - assert s.url_expanded == "https://user:$PASS@mydevpi.localhost" - - -def test_validation_error(): - data = {"name": "test", "verify_ssl": 1} - - with pytest.raises(TypeError) as exc_info: - Source(**data) - - error_message = str(exc_info.value) - - assert "missing 1 required positional argument: 'url'" in error_message - - data["url"] = "http://localhost:8000" - - with pytest.raises(models.ValidationError) as exc_info: - Source(**data) - - error_message = str(exc_info.value) - - - assert "verify_ssl: must be of boolean type" in error_message diff --git a/tests/test_pipfiles.py b/tests/test_pipfiles.py index 203b2e5..29e0493 100644 --- a/tests/test_pipfiles.py +++ b/tests/test_pipfiles.py @@ -1,5 +1,7 @@ import textwrap +import pytest + from plette import Pipfile from plette.models import PackageCollection, SourceCollection 
@@ -25,13 +27,13 @@ def test_source_section_transparent(): }, ]) section[0].verify_ssl = True - assert section == SourceCollection([ + assert section._data == [ { "name": "devpi", "url": "https://$USER:$PASS@mydevpi.localhost", "verify_ssl": True, }, - ]) + ] def test_package_section(): @@ -39,7 +41,10 @@ def test_package_section(): "flask": {"version": "*"}, "jinja2": "*", }) - assert section.packages["jinja2"].version == "*" + assert section["jinja2"].version == "*" + with pytest.raises(KeyError) as ctx: + section["mosql"] + assert str(ctx.value) == repr("mosql") def test_pipfile_load(tmpdir): @@ -50,18 +55,17 @@ def test_pipfile_load(tmpdir): jinja2 = '*' # A comment. """)) p = Pipfile.load(fi) - - assert p.source == SourceCollection([ + assert p["source"] == SourceCollection([ { 'url': 'https://pypi.org/simple', 'verify_ssl': True, 'name': 'pypi', }, ]) - assert p.packages == { + assert p["packages"] == PackageCollection({ "flask": {"version": "*"}, "jinja2": "*", - } + }) def test_pipfile_preserve_format(tmpdir): @@ -73,17 +77,17 @@ def test_pipfile_preserve_format(tmpdir): jinja2 = '*' """, )) - pf= Pipfile.load(fi) - pf.source[0].verify_ssl = False + p = Pipfile.load(fi) + p["source"][0].verify_ssl = False fo = tmpdir.join("Pipfile.out") - pf.dump(fo) + p.dump(fo) assert fo.read() == textwrap.dedent( """\ [[source]] name = "pypi" - verify_ssl = false url = "https://pypi.org/simple" + verify_ssl = false [packages] flask = { version = "*" } diff --git a/tests/test_scripts.py b/tests/test_scripts.py index 1087a50..a0864e6 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -1,5 +1,6 @@ import pytest +from plette.models.base import DataValidationError from plette.models import Script @@ -10,9 +11,12 @@ def test_parse(): def test_parse_error(): - with pytest.raises(IndexError): + with pytest.raises(DataValidationError) as ctx: Script('') + assert str(ctx.value) == "Script cannot be empty", ctx + + def test_cmdify(): script = Script(['python', '-c', 
"print('hello world')"]) cmd = script.cmdify()