Skip to content

live_server: fix in-memory DB detection, dj 2.2 compat #793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
9 changes: 6 additions & 3 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import pytest

from . import live_server_helper
from .django_compat import is_django_unittest
from .lazy_django import skip_if_no_django

Expand Down Expand Up @@ -385,6 +384,7 @@ def live_server(request):
"""
skip_if_no_django()

from . import live_server_helper
import django

addr = request.config.getvalue("liveserver") or os.getenv(
Expand Down Expand Up @@ -434,8 +434,11 @@ def _live_server_helper(request):
request.getfixturevalue("transactional_db")

live_server = request.getfixturevalue("live_server")
live_server._live_server_modified_settings.enable()
request.addfinalizer(live_server._live_server_modified_settings.disable)

modified_settings = live_server._dj_testcase._live_server_modified_settings
if not hasattr(modified_settings, "wrapped"):
modified_settings.enable()
request.addfinalizer(modified_settings.disable)


@contextmanager
Expand Down
86 changes: 10 additions & 76 deletions pytest_django/live_server_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,66 +10,28 @@ class LiveServer(object):
"""

def __init__(self, addr):
import django
from django.db import connections
from django.test.testcases import LiveServerThread
from django.test.utils import modify_settings

connections_override = {}
for conn in connections.all():
# If using in-memory sqlite databases, pass the connections to
# the server thread.
if (
conn.settings_dict["ENGINE"] == "django.db.backends.sqlite3"
and conn.settings_dict["NAME"] == ":memory:"
):
# Explicitly enable thread-shareability for this connection
conn.allow_thread_sharing = True
connections_override[conn.alias] = conn

liveserver_kwargs = {"connections_override": connections_override}
from django.test.testcases import LiveServerTestCase
from django.conf import settings

if "django.contrib.staticfiles" in settings.INSTALLED_APPS:
from django.contrib.staticfiles.handlers import StaticFilesHandler

liveserver_kwargs["static_handler"] = StaticFilesHandler
else:
from django.test.testcases import _StaticFilesHandler

liveserver_kwargs["static_handler"] = _StaticFilesHandler

if django.VERSION < (1, 11):
host, possible_ports = parse_addr(addr)
self.thread = LiveServerThread(host, possible_ports, **liveserver_kwargs)
else:
try:
host, port = addr.split(":")
except ValueError:
host = addr
else:
liveserver_kwargs["port"] = int(port)
self.thread = LiveServerThread(host, **liveserver_kwargs)
from django.test.testcases import _StaticFilesHandler as StaticFilesHandler

self._live_server_modified_settings = modify_settings(
ALLOWED_HOSTS={"append": host}
)
class CustomLiveServerTestCase(LiveServerTestCase):
static_handler = StaticFilesHandler

self.thread.daemon = True
self.thread.start()
self.thread.is_ready.wait()

if self.thread.error:
raise self.thread.error
self._dj_testcase = CustomLiveServerTestCase("__init__")
self._dj_testcase.setUpClass()

def stop(self):
"""Stop the server"""
self.thread.terminate()
self.thread.join()
if not hasattr(self._dj_testcase._live_server_modified_settings, "wrapped"):
self._dj_testcase._live_server_modified_settings.enable()
self._dj_testcase.tearDownClass()

@property
def url(self):
return "http://%s:%s" % (self.thread.host, self.thread.port)
return self._dj_testcase.live_server_url

def __str__(self):
return self.url
Expand All @@ -79,31 +41,3 @@ def __add__(self, other):

def __repr__(self):
return "<LiveServer listening at %s>" % self.url


def parse_addr(specified_address):
"""Parse the --liveserver argument into a host/IP address and port range"""
# This code is based on
# django.test.testcases.LiveServerTestCase.setUpClass

# The specified ports may be of the form '8000-8010,8080,9200-9300'
# i.e. a comma-separated list of ports or ranges of ports, so we break
# it down into a detailed list of all possible ports.
possible_ports = []
try:
host, port_ranges = specified_address.split(":")
for port_range in port_ranges.split(","):
# A port range can be of either form: '8000' or '8000-8010'.
extremes = list(map(int, port_range.split("-")))
assert len(extremes) in (1, 2)
if len(extremes) == 1:
# Port range of the form '8000'
possible_ports.append(extremes[0])
else:
# Port range of the form '8000-8010'
for port in range(extremes[0], extremes[1] + 1):
possible_ports.append(port)
except Exception:
raise Exception('Invalid address ("%s") for live server.' % specified_address)

return host, possible_ports