@@ -28,6 +28,10 @@ def sample_data(collection):
28
28
for x in range (1 , 6 ):
29
29
collection .add ({"x" : x })
30
30
31
+ subcollection_doc = collection .document ("subcollection" )
32
+ subcollection_doc .set ({})
33
+ subcollection_doc .collection ("subcollection1" ).add ({})
34
+
31
35
32
36
# ===== Query =====
33
37
@@ -111,3 +115,80 @@ def _test():
111
115
def test_firestore_aggregation_query_generators(collection, assert_trace_for_generator):
    """Verify the stream generator of a count() aggregation query is traced."""
    filtered_query = collection.select("x").where(field_path="x", op_string="<=", value=3)
    aggregation_query = filtered_query.count()
    assert_trace_for_generator(aggregation_query.stream)
118
+
119
+
120
+ # ===== CollectionGroup =====
121
+
122
+
123
@pytest.fixture()
def patch_partition_queries(monkeypatch, client, collection, sample_data):
    """Patch the partition_query RPC so CollectionGroup partitioning can be exercised.

    Partitioning is not implemented in the Firestore emulator.

    Ordinarily this method would return a generator of Cursor objects. Each Cursor must point at a valid document path.
    To test this, we can patch the RPC to return 1 Cursor which is pointed at any document available.
    The get_partitions will take that and make 2 QueryPartition objects out of it, which should be enough to ensure
    we can exercise the generator's tracing.
    """
    from google.cloud.firestore_v1.types.document import Value
    from google.cloud.firestore_v1.types.query import Cursor

    subcollection = collection.document("subcollection").collection("subcollection1")
    # list() instead of a copying comprehension (ruff PERF402).
    # sample_data is required so at least one document exists here — TODO confirm
    # the fixture populates this subcollection before this one runs.
    documents = list(subcollection.list_documents())

    def mock_partition_query(*args, **kwargs):
        # A single Cursor at a real document path; get_partitions turns this
        # into 2 QueryPartition objects (before/after the cursor).
        yield Cursor(before=False, values=[Value(reference_value=documents[0].path)])

    monkeypatch.setattr(client._firestore_api, "partition_query", mock_partition_query)
    yield
144
+
145
+
146
def _exercise_collection_group(collection):
    """Run CollectionGroup get/stream/get_partitions so their traces are generated.

    Assumes the sample_data fixture has populated the collection (6 documents
    total — TODO confirm against the fixture) and that partition_query is
    patched to yield exactly one cursor, producing 2 partitions.
    """
    from google.cloud.firestore import CollectionGroup

    collection_group = CollectionGroup(collection)
    assert len(collection_group.get())
    # Consume the stream generator fully without materializing a throwaway
    # list just to test truthiness (ruff PERF101).
    assert sum(1 for _ in collection_group.stream())

    # list() instead of a copying comprehension (ruff PERF402).
    partitions = list(collection_group.get_partitions(1))
    assert len(partitions) == 2

    # Run each partition's query; together they should cover every document.
    documents = []
    while partitions:
        documents.extend(partitions.pop().query().get())
    assert len(documents) == 6
159
+
160
+
161
def test_firestore_collection_group(collection, patch_partition_queries):
    """Check Datastore metrics produced by exercising a CollectionGroup."""
    # One op -> call-count table drives both metric lists; the generated
    # metric-name strings are identical to spelling each tuple out by hand.
    expected_op_counts = {"get": 3, "stream": 1, "get_partitions": 1}

    scoped_metrics = [
        ("Datastore/statement/Firestore/%s/%s" % (collection.id, operation), count)
        for operation, count in expected_op_counts.items()
    ]

    rollup_metrics = [
        ("Datastore/operation/Firestore/%s" % operation, count)
        for operation, count in expected_op_counts.items()
    ]
    rollup_metrics.append(("Datastore/all", 5))
    rollup_metrics.append(("Datastore/allOther", 5))

    @validate_database_duration()
    @validate_transaction_metrics(
        "test_firestore_collection_group",
        scoped_metrics=scoped_metrics,
        rollup_metrics=rollup_metrics,
        background_task=True,
    )
    @background_task(name="test_firestore_collection_group")
    def _test():
        _exercise_collection_group(collection)

    _test()
188
+
189
+
190
@background_task()
def test_firestore_collection_group_generators(collection, assert_trace_for_generator, patch_partition_queries):
    """Verify the get_partitions generator of a CollectionGroup is traced."""
    from google.cloud.firestore import CollectionGroup

    assert_trace_for_generator(CollectionGroup(collection).get_partitions, 1)
0 commit comments