Source code for torchwrench.extras.yaml.yaml

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import io
from argparse import Namespace
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping, Optional, Union

import lazy_loader as lazy
from pythonwrench.functools import function_alias
from pythonwrench.typing import DataclassInstance, NamedTupleInstance

from torchwrench.core.packaging import _OMEGACONF_AVAILABLE, _YAML_AVAILABLE
from torchwrench.serialization.common import as_builtin

from .definitions import (
    MappingNode,
    Node,
    ParserError,
    SafeLoader,
    ScalarNode,
    ScannerError,
    SequenceNode,
    YamlLoaders,
    yaml,
)

if TYPE_CHECKING:
    import omegaconf
else:
    omegaconf = lazy.load("omegaconf", require="omegaconf")


def dump_yaml(
    data: Union[
        Iterable[Any],
        Mapping[str, Any],
        Namespace,
        DataclassInstance,
        NamedTupleInstance,
    ],
    fpath: Union[str, Path, None] = None,
    *,
    overwrite: bool = True,
    to_builtins: bool = False,
    make_parents: bool = True,
    resolve: bool = False,
    encoding: Optional[str] = "utf-8",
    # YAML dump kwargs
    sort_keys: bool = False,
    indent: Union[int, None] = None,
    width: Union[int, None] = 1000,
    allow_unicode: bool = True,
    **yaml_dump_kwds,
) -> str:
    """Dump content to YAML format and optionally write it to a file.

    Args:
        data: Object to serialize.
        fpath: Optional output file path. ``~`` is expanded before the path is
            made absolute, so ``"~/cfg.yaml"`` targets the home directory.
            When None, nothing is written. Defaults to None.
        overwrite: If False and ``fpath`` already exists, raise instead of
            overwriting it. Defaults to True.
        to_builtins: If True, convert ``data`` with ``as_builtin`` before
            dumping. Defaults to False.
        make_parents: If True, create missing parent directories of ``fpath``.
            Defaults to True.
        resolve: If True, convert ``data`` with OmegaConf and resolve it
            (``OmegaConf.to_container(..., resolve=True)``) before dumping.
            Requires omegaconf. Defaults to False.
        encoding: Text encoding used to write ``fpath``. Defaults to "utf-8".
        sort_keys: Forwarded to ``yaml.dump``. Defaults to False.
        indent: Forwarded to ``yaml.dump``. Defaults to None.
        width: Forwarded to ``yaml.dump``. Defaults to 1000.
        allow_unicode: Forwarded to ``yaml.dump``. Defaults to True.
        **yaml_dump_kwds: Extra keyword arguments forwarded to ``yaml.dump``.

    Returns:
        The YAML content as a string (returned whether or not it is written).

    Raises:
        ImportError: If pyyaml is not installed.
        ValueError: If ``resolve=True`` but omegaconf is not installed.
        FileExistsError: If ``overwrite=False`` and ``fpath`` already exists.
    """
    if not _YAML_AVAILABLE:
        msg = f"Cannot use python module {__file__} since pyyaml package is not installed. Please install it with `pip install torchwrench[extras]`."
        raise ImportError(msg)
    if not _OMEGACONF_AVAILABLE and resolve:
        msg = (
            "Cannot resolve yaml config without omegaconf package. "
            "Please use resolve=False or install omegaconf with `pip install torchwrench[extras]`."
        )
        raise ValueError(msg)

    if fpath is not None:
        # Expand "~" BEFORE resolving: resolve() on an unexpanded "~/x" yields
        # "<cwd>/~/x", after which expanduser() has nothing left to expand.
        fpath = Path(fpath).expanduser().resolve()
        if not overwrite and fpath.exists():
            raise FileExistsError(f"File {fpath} already exists.")
        elif make_parents:
            fpath.parent.mkdir(parents=True, exist_ok=True)

    if resolve:
        OmegaConf = omegaconf.OmegaConf
        data = OmegaConf.create(data)  # type: ignore
        data = OmegaConf.to_container(data, resolve=True)  # type: ignore

    if to_builtins:
        data = as_builtin(data)

    content = yaml.dump(
        data,
        sort_keys=sort_keys,
        indent=indent,
        width=width,
        allow_unicode=allow_unicode,
        **yaml_dump_kwds,
    )
    if fpath is not None:
        fpath.write_text(content, encoding=encoding)
    return content
@function_alias(dump_yaml)
def dumps_yaml(*args, **kwargs):
    """Alias of :func:`dump_yaml`, declared through ``function_alias``."""
@function_alias(dump_yaml)
def save_yaml(*args, **kwargs):
    """Alias of :func:`dump_yaml`, declared through ``function_alias``."""
def load_yaml(
    file: Union[str, Path, io.TextIOBase],
    *,
    Loader: YamlLoaders = SafeLoader,
    on_error: Literal["raise", "ignore"] = "raise",
    encoding: Optional[str] = "utf-8",
) -> Any:
    """Load YAML from a file path or an already-opened text stream.

    Args:
        file: Path to a YAML file, or an open text stream to read from.
        Loader: YAML loader class passed to ``loads_yaml``. Defaults to
            ``SafeLoader``.
        on_error: Forwarded to ``loads_yaml``; "ignore" makes YAML scanner /
            parser errors return None instead of raising. Defaults to "raise".
        encoding: Text encoding used to open ``file`` when it is a path.
            Defaults to "utf-8", the same default ``dump_yaml`` writes with,
            instead of the platform-dependent locale encoding. Ignored when
            ``file`` is already a stream.

    Returns:
        The parsed YAML data.

    Raises:
        ImportError: If pyyaml is not installed.
        TypeError: If ``file`` is neither a path nor a text stream.
    """
    if not _YAML_AVAILABLE:
        msg = f"Cannot use python module {__file__} since pyyaml package is not installed. Please install it with `pip install torchwrench[extras]`."
        raise ImportError(msg)

    if isinstance(file, (str, Path)):
        # Explicit encoding: open()'s default depends on the OS locale, which
        # mis-decodes UTF-8 files (e.g. written by dump_yaml) on some systems.
        with open(file, "r", encoding=encoding) as buffer:
            return loads_yaml(buffer, Loader=Loader, on_error=on_error)
    elif isinstance(file, io.TextIOBase):
        return loads_yaml(file, Loader=Loader, on_error=on_error)
    else:
        msg = f"Invalid argument type {type(file)}."
        raise TypeError(msg)
def loads_yaml(
    content: Union[str, io.TextIOBase],
    *,
    Loader: YamlLoaders = SafeLoader,
    on_error: Literal["raise", "ignore"] = "raise",
) -> Any:
    """Load YAML from a string or a text stream.

    Args:
        content: YAML document as a string, or an open text stream.
        Loader: YAML loader class passed to ``yaml.load``. Defaults to
            ``SafeLoader``.
        on_error: "raise" re-raises ``ScannerError`` / ``ParserError``;
            "ignore" returns None for those errors instead. Other exceptions
            always propagate. Defaults to "raise".

    Returns:
        The parsed YAML data, or None when a scanner/parser error is ignored.

    Raises:
        ImportError: If pyyaml is not installed.
        ScannerError: On malformed YAML when ``on_error="raise"``.
        ParserError: On malformed YAML when ``on_error="raise"``.
    """
    # Same availability guard as load_yaml / dump_yaml, so calling this
    # directly gives the same clear message when pyyaml is missing.
    if not _YAML_AVAILABLE:
        msg = f"Cannot use python module {__file__} since pyyaml package is not installed. Please install it with `pip install torchwrench[extras]`."
        raise ImportError(msg)

    if isinstance(content, str):
        with io.StringIO(content) as buffer:
            return loads_yaml(buffer, Loader=Loader, on_error=on_error)

    try:
        # NOTE(security): loaders other than SafeLoader may build Python
        # objects from tags; only pass them for trusted input.
        return yaml.load(content, Loader=Loader)  # type: ignore
    except (ScannerError, ParserError):
        if on_error == "ignore":
            return None
        raise
@function_alias(load_yaml)
def read_yaml(*args, **kwargs):
    """Alias of :func:`load_yaml`, declared through ``function_alias``."""
class IgnoreTagLoader(SafeLoader):  # type: ignore
    r"""SafeLoader variant that drops YAML tags instead of failing on them.

    Nodes routed to :meth:`construct_with_tag` (registered at module level as
    a multi-constructor for the ``!`` and ``tag:`` prefixes) are built as a
    plain mapping, scalar or sequence, and their tag is discarded.

    Examples
    ========

    ```python
    dumped = "a: !!python/tuple\n- 1\n- 2"
    yaml.load(dumped, Loader=IgnoreTagLoader)  # {'a': [1, 2]}
    yaml.load(dumped, Loader=FullLoader)  # {'a': (1, 2)}
    yaml.load(dumped, Loader=SafeLoader)  # raises ConstructorError
    ```
    """

    def construct_with_tag(self, tag: str, node: Node) -> Any:
        """Construct ``node`` as though it had no tag.

        Args:
            tag: Tag suffix passed by the multi-constructor machinery; only
                used in the error message.
            node: YAML node to construct.

        Returns:
            The constructed scalar, sequence or mapping.

        Raises:
            NotImplementedError: If ``node`` is not a scalar, sequence or
                mapping node.
        """
        if isinstance(node, ScalarNode):
            return self.construct_scalar(node)
        if isinstance(node, SequenceNode):
            return self.construct_sequence(node)
        if isinstance(node, MappingNode):
            return self.construct_mapping(node)

        raise NotImplementedError(
            f"Unsupported node type {type(node)} with {tag=}."
        )
class SplitTagLoader(SafeLoader):  # type: ignore
    r"""SafeLoader variant that stores each tag alongside its value.

    A tagged node handled by :meth:`construct_with_tag` becomes a dict with
    two entries: the tag under ``tag_key`` and the constructed value under
    ``args_key``.

    Examples
    ========

    ```python
    dumped = "a: !!python/tuple\n- 1\n- 2"
    yaml.load(dumped, Loader=SplitTagLoader)
    # {'a': {'_target_': 'yaml.org,2002:python/tuple', '_args_': [1, 2]}}
    ```
    """

    def __init__(
        self,
        stream,
        *,
        tag_key: str = "_target_",
        args_key: str = "_args_",
    ) -> None:
        """Create the loader.

        Args:
            stream: Input passed unchanged to ``SafeLoader.__init__``.
            tag_key: Key under which the tag is stored. Defaults to "_target_".
            args_key: Key under which the constructed value is stored.
                Defaults to "_args_".
        """
        super().__init__(stream)
        self.tag_key, self.args_key = tag_key, args_key

    def construct_with_tag(self, tag: str, node: Node) -> Any:
        """Construct ``node`` and wrap it together with its tag.

        Args:
            tag: Tag suffix passed by the multi-constructor machinery.
            node: YAML node to construct.

        Returns:
            ``{self.tag_key: tag, self.args_key: <constructed value>}``.

        Raises:
            NotImplementedError: If ``node`` is not a scalar, sequence or
                mapping node.
        """
        if isinstance(node, ScalarNode):
            value = self.construct_scalar(node)
        elif isinstance(node, SequenceNode):
            value = self.construct_sequence(node)
        elif isinstance(node, MappingNode):
            value = self.construct_mapping(node)
        else:
            raise NotImplementedError(
                f"Unsupported node type {type(node)} with {tag=}."
            )

        return {self.tag_key: tag, self.args_key: value}
# Route every local ("!...") and global ("tag:...") tag that has no dedicated
# constructor to each loader's construct_with_tag. Call order is identical to
# registering IgnoreTagLoader's two prefixes first, then SplitTagLoader's.
for _tag_loader in (IgnoreTagLoader, SplitTagLoader):
    for _tag_prefix in ("!", "tag:"):
        _tag_loader.add_multi_constructor(_tag_prefix, _tag_loader.construct_with_tag)  # type: ignore
# Remove loop variables so the module namespace is unchanged.
del _tag_loader, _tag_prefix