# Source code for vcr.config

import copy
import functools
import inspect
import os
import types
from collections import abc as collections_abc
from pathlib import Path

from . import filters, matchers
from .cassette import Cassette
from .persisters.filesystem import FilesystemPersister
from .record_mode import RecordMode
from .serializers import jsonserializer, yamlserializer
from .util import auto_decorate, compose


class VCR:
    """Hold default recording/playback settings and build cassettes from them.

    Every keyword accepted by ``__init__`` acts as a default. Most of them can
    be overridden for a single cassette through :meth:`use_cassette`, whose
    keyword arguments are merged with these defaults by :meth:`get_merged_config`.
    """

    @staticmethod
    def is_test_method(method_name, function):
        """Return True for plain functions whose name starts with ``"test"``.

        This is the default predicate used by :meth:`test_case`.
        """
        return method_name.startswith("test") and isinstance(function, types.FunctionType)

    @staticmethod
    def ensure_suffix(suffix):
        """Build a path transformer that appends *suffix* unless it is already present.

        :param suffix: string suffix to enforce (e.g. ``".yaml"``).
        :returns: a one-argument callable mapping a path to a suffixed path.
            ``str`` input behaves as before. ``os.PathLike`` input (for example
            :class:`pathlib.Path`, which ``use_cassette`` accepts) is converted to
            ``str`` first, because path objects have no ``endswith`` method.
        """

        def ensure(path):
            if isinstance(path, os.PathLike):
                path = os.fspath(path)
            if not path.endswith(suffix):
                return path + suffix
            return path

        return ensure

    def __init__(
        self,
        path_transformer=None,
        before_record_request=None,
        custom_patches=(),
        filter_query_parameters=(),
        ignore_hosts=(),
        record_mode=RecordMode.ONCE,
        ignore_localhost=False,
        filter_headers=(),
        before_record_response=None,
        filter_post_data_parameters=(),
        match_on=("method", "scheme", "host", "port", "path", "query"),
        before_record=None,
        inject_cassette=False,
        serializer="yaml",
        cassette_library_dir=None,
        func_path_generator=None,
        decode_compressed_response=False,
        record_on_exception=True,
    ):
        """Store the default configuration.

        :param serializer: name of a registered serializer (``"yaml"`` or ``"json"``
            unless more are added with :meth:`register_serializer`).
        :param match_on: names of registered matchers used to compare requests.
        :param before_record: legacy alias, used only when
            ``before_record_request`` is not given.
        :param custom_patches: iterable of extra patches, stored as a tuple.
        :param cassette_library_dir: directory that relative cassette paths are
            joined onto (see :meth:`get_merged_config`).
        """
        self.serializer = serializer
        self.match_on = match_on
        self.cassette_library_dir = cassette_library_dir
        self.serializers = {"yaml": yamlserializer, "json": jsonserializer}
        self.matchers = {
            "method": matchers.method,
            "uri": matchers.uri,
            "url": matchers.uri,  # matcher for backwards compatibility
            "scheme": matchers.scheme,
            "host": matchers.host,
            "port": matchers.port,
            "path": matchers.path,
            "query": matchers.query,
            "headers": matchers.headers,
            "raw_body": matchers.raw_body,
            "body": matchers.body,
        }
        self.persister = FilesystemPersister
        self.record_mode = record_mode
        self.filter_headers = filter_headers
        self.filter_query_parameters = filter_query_parameters
        self.filter_post_data_parameters = filter_post_data_parameters
        self.before_record_request = before_record_request or before_record
        self.before_record_response = before_record_response
        self.ignore_hosts = ignore_hosts
        self.ignore_localhost = ignore_localhost
        self.inject_cassette = inject_cassette
        self.path_transformer = path_transformer
        self.func_path_generator = func_path_generator
        self.decode_compressed_response = decode_compressed_response
        self.record_on_exception = record_on_exception
        self._custom_patches = tuple(custom_patches)

    def _get_serializer(self, serializer_name):
        """Look up a registered serializer by name.

        :raises KeyError: if *serializer_name* has not been registered.
        """
        try:
            serializer = self.serializers[serializer_name]
        except KeyError:
            raise KeyError(f"Serializer {serializer_name} doesn't exist or isn't registered") from None
        return serializer

    def _get_matchers(self, matcher_names):
        """Resolve matcher names to matcher callables, preserving order.

        :raises KeyError: naming the first unregistered matcher.
        """
        resolved = []  # not named ``matchers`` to avoid shadowing the module import
        try:
            for m in matcher_names:
                resolved.append(self.matchers[m])
        except KeyError:
            raise KeyError(f"Matcher {m} doesn't exist or isn't registered") from None
        return resolved

    def use_cassette(self, path=None, **kwargs):
        """Return a cassette context manager/decorator configured from this VCR.

        Supports both ``@vcr.use_cassette("path")`` and bare ``@vcr.use_cassette``:
        if *path* is neither a ``str`` nor a ``Path``, it is assumed to be the
        function being decorated.
        """
        if path is not None and not isinstance(path, (str, Path)):
            function = path
            # Assume this is an attempt to decorate a function
            return self._use_cassette(**kwargs)(function)
        return self._use_cassette(path=path, **kwargs)

    def _use_cassette(self, with_current_defaults=False, **kwargs):
        """Create the cassette object via ``Cassette.use`` / ``Cassette.use_arg_getter``.

        With ``with_current_defaults`` the configuration is merged immediately.
        Otherwise it is merged lazily each time a cassette is created.
        """
        if with_current_defaults:
            config = self.get_merged_config(**kwargs)
            return Cassette.use(**config)
        # This is made a function that evaluates every time a cassette
        # is made so that changes that are made to this VCR instance
        # that occur AFTER the `use_cassette` decorator is applied
        # still affect subsequent calls to the decorated function.
        args_getter = functools.partial(self.get_merged_config, **kwargs)
        return Cassette.use_arg_getter(args_getter)

    def get_merged_config(self, **kwargs):
        """Merge per-call *kwargs* with this instance's defaults.

        :returns: a dict of keyword arguments for the cassette. ``"path"`` is
            included only when a truthy path was given.
        :raises KeyError: for an unknown serializer or matcher name.
        """
        serializer_name = kwargs.get("serializer", self.serializer)
        matcher_names = kwargs.get("match_on", self.match_on)
        path_transformer = kwargs.get("path_transformer", self.path_transformer)
        func_path_generator = kwargs.get("func_path_generator", self.func_path_generator)
        cassette_library_dir = kwargs.get("cassette_library_dir", self.cassette_library_dir)
        additional_matchers = kwargs.get("additional_matchers", ())
        record_on_exception = kwargs.get("record_on_exception", self.record_on_exception)

        if cassette_library_dir:
            # A Path library dir would break ``str.startswith`` below, so normalise it.
            if isinstance(cassette_library_dir, os.PathLike):
                cassette_library_dir = os.fspath(cassette_library_dir)

            def add_cassette_library_dir(path):
                # use_cassette accepts pathlib.Path, which has no ``startswith``.
                if isinstance(path, os.PathLike):
                    path = os.fspath(path)
                # NOTE: plain string-prefix test, so a sibling directory sharing
                # the prefix (e.g. "<dir>_old/...") is treated as already inside.
                if not path.startswith(cassette_library_dir):
                    return os.path.join(cassette_library_dir, path)
                return path

            # ``compose`` applies the user transformer first, then the join.
            path_transformer = compose(add_cassette_library_dir, path_transformer)
        elif not func_path_generator:
            # If we don't have a library dir, use the functions
            # location to build a full path for cassettes.
            func_path_generator = self._build_path_from_func_using_module

        merged_config = {
            "serializer": self._get_serializer(serializer_name),
            "persister": self.persister,
            "match_on": self._get_matchers(tuple(matcher_names) + tuple(additional_matchers)),
            "record_mode": kwargs.get("record_mode", self.record_mode),
            "before_record_request": self._build_before_record_request(kwargs),
            "before_record_response": self._build_before_record_response(kwargs),
            # Coerce like __init__ does, so a per-call list does not raise TypeError.
            "custom_patches": self._custom_patches + tuple(kwargs.get("custom_patches", ())),
            "inject": kwargs.get("inject_cassette", self.inject_cassette),
            "path_transformer": path_transformer,
            "func_path_generator": func_path_generator,
            "allow_playback_repeats": kwargs.get("allow_playback_repeats", False),
            "record_on_exception": record_on_exception,
        }
        path = kwargs.get("path")
        if path:
            merged_config["path"] = path
        return merged_config

    def _build_before_record_response(self, options):
        """Build the response filter chain.

        The optional ``decode_response`` filter runs first, followed by the
        user's ``before_record_response`` callable(s). The chain stops as soon
        as a filter returns ``None``.
        """
        before_record_response = options.get("before_record_response", self.before_record_response)
        decode_compressed_response = options.get(
            "decode_compressed_response",
            self.decode_compressed_response,
        )
        filter_functions = []
        if decode_compressed_response:
            filter_functions.append(filters.decode_response)
        if before_record_response:
            # Accept either a single callable or an iterable of callables.
            if not isinstance(before_record_response, collections_abc.Iterable):
                before_record_response = (before_record_response,)
            filter_functions.extend(before_record_response)

        def before_record_response(response):
            for function in filter_functions:
                if response is None:
                    break
                response = function(response)
            return response

        return before_record_response

    def _build_before_record_request(self, options):
        """Build the request filter chain.

        The chain applies, in order: header filters, query-parameter filters,
        POST-data filters, the ignored-host filter, then the user's
        ``before_record_request`` callable(s). Each request is deep-copied
        before filtering, and the chain stops once a filter returns ``None``.
        """
        filter_functions = []
        filter_headers = options.get("filter_headers", self.filter_headers)
        filter_query_parameters = options.get("filter_query_parameters", self.filter_query_parameters)
        filter_post_data_parameters = options.get(
            "filter_post_data_parameters",
            self.filter_post_data_parameters,
        )
        # ``before_record`` is the legacy alias for ``before_record_request``.
        before_record_request = options.get(
            "before_record_request",
            options.get("before_record", self.before_record_request),
        )
        ignore_hosts = options.get("ignore_hosts", self.ignore_hosts)
        ignore_localhost = options.get("ignore_localhost", self.ignore_localhost)
        # Bare names mean "remove" (replacement None); tuples are (name, replacement).
        if filter_headers:
            replacements = [h if isinstance(h, tuple) else (h, None) for h in filter_headers]
            filter_functions.append(functools.partial(filters.replace_headers, replacements=replacements))
        if filter_query_parameters:
            replacements = [p if isinstance(p, tuple) else (p, None) for p in filter_query_parameters]
            filter_functions.append(
                functools.partial(filters.replace_query_parameters, replacements=replacements),
            )
        if filter_post_data_parameters:
            replacements = [p if isinstance(p, tuple) else (p, None) for p in filter_post_data_parameters]
            filter_functions.append(
                functools.partial(filters.replace_post_data_parameters, replacements=replacements),
            )

        hosts_to_ignore = set(ignore_hosts)
        if ignore_localhost:
            hosts_to_ignore.update(("localhost", "0.0.0.0", "127.0.0.1"))
        if hosts_to_ignore:
            filter_functions.append(self._build_ignore_hosts(hosts_to_ignore))

        if before_record_request:
            # Accept either a single callable or an iterable of callables.
            if not isinstance(before_record_request, collections_abc.Iterable):
                before_record_request = (before_record_request,)
            filter_functions.extend(before_record_request)

        def before_record_request(request):
            # Filters may mutate; never touch the caller's request object.
            request = copy.deepcopy(request)
            for function in filter_functions:
                if request is None:
                    break
                request = function(request)
            return request

        return before_record_request

    @staticmethod
    def _build_ignore_hosts(hosts_to_ignore):
        """Return a filter that maps requests to ignored hosts to ``None``."""

        def filter_ignored_hosts(request):
            if hasattr(request, "host") and request.host in hosts_to_ignore:
                return
            return request

        return filter_ignored_hosts

    @staticmethod
    def _build_path_from_func_using_module(function):
        """Return ``<directory of function's source file>/<function name>``."""
        return os.path.join(os.path.dirname(inspect.getfile(function)), function.__name__)

    def register_serializer(self, name, serializer):
        """Register *serializer* under *name*, replacing any existing entry."""
        self.serializers[name] = serializer

    def register_matcher(self, name, matcher):
        """Register *matcher* under *name* for use in ``match_on``."""
        self.matchers[name] = matcher

    def register_persister(self, persister):
        """Replace the persister used for every cassette built by this VCR."""
        # Singleton, no name required
        self.persister = persister

    def test_case(self, predicate=None):
        """Return a base class built with ``auto_decorate(self.use_cassette, predicate)``.

        *predicate* defaults to :meth:`is_test_method`. Presumably, methods of
        subclasses matching the predicate get wrapped with ``use_cassette``;
        confirm against ``vcr.util.auto_decorate``.
        """
        predicate = predicate or self.is_test_method
        metaclass = auto_decorate(self.use_cassette, predicate)
        return metaclass("temporary_class", (), {})