diff --git a/flask_pymongo/__init__.py b/flask_pymongo/__init__.py index 72aef75..03ebb81 100644 --- a/flask_pymongo/__init__.py +++ b/flask_pymongo/__init__.py @@ -123,7 +123,12 @@ def init_app(self, app: Flask, uri: str | None = None, *args: Any, **kwargs: Any # view helpers def send_file( - self, filename: str, base: str = "fs", version: int = -1, cache_for: int = 31536000 + self, + filename: str, + base: str = "fs", + version: int = -1, + cache_for: int = 31536000, + db: str | None = None, ) -> Response: """Respond with a file from GridFS. @@ -144,6 +149,7 @@ def get_upload(filename): revision. If no such version exists, return with HTTP status 404. :param int cache_for: number of seconds that browsers should be instructed to cache responses + :param str db: the target database, if different from the default database. """ if not isinstance(base, str): raise TypeError("'base' must be string or unicode") @@ -152,8 +158,13 @@ def get_upload(filename): if not isinstance(cache_for, int): raise TypeError("'cache_for' must be an integer") - assert self.db is not None, "Please initialize the app before calling send_file!" - storage = GridFS(self.db, base) + if db: + db_obj = self.cx[db] + else: + db_obj = self.db + + assert db_obj is not None, "Please initialize the app before calling send_file!" + storage = GridFS(db_obj, base) try: fileobj = storage.get_version(filename=filename, version=version) @@ -189,6 +200,7 @@ def save_file( fileobj: Any, base: str = "fs", content_type: str | None = None, + db: str | None = None, **kwargs: Any, ) -> Any: """Save a file-like object to GridFS using the given filename. @@ -207,6 +219,7 @@ def save_upload(filename): :param str content_type: the MIME content-type of the file. If ``None``, the content-type is guessed from the filename using :func:`~mimetypes.guess_type` + :param str db: the target database, if different from the default database. :param kwargs: extra attributes to be stored in the file's document, passed directly to :meth:`gridfs.GridFS.put` """ @@ -218,7 +231,11 @@ def save_upload(filename): if content_type is None: content_type, _ = guess_type(filename) - assert self.db is not None, "Please initialize the app before calling save_file!" - storage = GridFS(self.db, base) + if db: + db_obj = self.cx[db] + else: + db_obj = self.db + assert db_obj is not None, "Please initialize the app before calling save_file!" + storage = GridFS(db_obj, base) id = storage.put(fileobj, filename=filename, content_type=content_type, **kwargs) return id diff --git a/tests/test_gridfs.py b/tests/test_gridfs.py index 93a6672..ee6b6e3 100644 --- a/tests/test_gridfs.py +++ b/tests/test_gridfs.py @@ -30,6 +30,14 @@ def test_it_saves_files(self): gridfs = GridFS(self.mongo.db) assert gridfs.exists({"filename": "my-file"}) + def test_it_saves_files_to_another_db(self): + fileobj = BytesIO(b"these are the bytes") + + self.mongo.save_file("my-file", fileobj, db="other") + assert self.mongo.db is not None + gridfs = GridFS(self.mongo.cx["other"]) + assert gridfs.exists({"filename": "my-file"}) + def test_it_saves_files_with_props(self): fileobj = BytesIO(b"these are the bytes") @@ -56,6 +64,7 @@ def setUp(self): # make it bigger than 1 gridfs chunk self.myfile = BytesIO(b"a" * 500 * 1024) self.mongo.save_file("myfile.txt", self.myfile) + self.mongo.save_file("my_other_file.txt", self.myfile, db="other") def test_it_404s_for_missing_files(self): with pytest.raises(NotFound): @@ -65,6 +74,10 @@ def test_it_sets_content_type(self): resp = self.mongo.send_file("myfile.txt") assert resp.content_type.startswith("text/plain") + def test_it_sends_file_to_another_db(self): + resp = self.mongo.send_file("my_other_file.txt", db="other") + assert resp.content_type.startswith("text/plain") + def test_it_sets_content_length(self): resp = self.mongo.send_file("myfile.txt") assert resp.content_length == len(self.myfile.getvalue())