Skip to content

Commit b381416

Browse files
authored
Merge pull request #957 from DarthMax/implement_missing_write_tests
GDSA-146 Add missing write tests
2 parents 4d1e3fd + 6c0e4b4 commit b381416

40 files changed

+1047
-290
lines changed

graphdatascience/procedure_surface/arrow/articlerank_arrow_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def write(
172172
)
173173

174174
result = self._node_property_endpoints.run_job_and_write(
175-
"v2/centrality.articleRank", G, config, write_concurrency, concurrency
175+
"v2/centrality.articleRank", G, config, write_concurrency, concurrency, write_property
176176
)
177177

178178
return ArticleRankWriteResult(**result)

graphdatascience/procedure_surface/arrow/articulationpoints_arrow_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def write(
123123
)
124124

125125
result = self._node_property_endpoints.run_job_and_write(
126-
"v2/centrality.articulationPoints", G, config, write_concurrency, concurrency
126+
"v2/centrality.articulationPoints", G, config, write_concurrency, concurrency, write_property
127127
)
128128

129129
return ArticulationPointsWriteResult(**result)

graphdatascience/procedure_surface/arrow/betweenness_arrow_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def write(
154154
)
155155

156156
result = self._node_property_endpoints.run_job_and_write(
157-
"v2/centrality.betweenness", G, config, write_concurrency, concurrency
157+
"v2/centrality.betweenness", G, config, write_concurrency, concurrency, write_property
158158
)
159159

160160
return BetweennessWriteResult(**result)

graphdatascience/procedure_surface/arrow/celf_arrow_endpoints.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,12 @@ def write(
147147
)
148148

149149
result = self._node_property_endpoints.run_job_and_write(
150-
"v2/centrality.celf", G, config, write_concurrency=write_concurrency, concurrency=concurrency
150+
"v2/centrality.celf",
151+
G,
152+
config,
153+
write_concurrency=write_concurrency,
154+
concurrency=concurrency,
155+
property_overwrites=write_property,
151156
)
152157

153158
return CelfWriteResult(**result)

graphdatascience/procedure_surface/arrow/closeness_arrow_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def write(
134134
)
135135

136136
result = self._node_property_endpoints.run_job_and_write(
137-
"v2/centrality.closeness", G, config, write_concurrency, concurrency
137+
"v2/centrality.closeness", G, config, write_concurrency, concurrency, write_property
138138
)
139139

140140
return ClosenessWriteResult(**result)

graphdatascience/procedure_surface/arrow/closeness_harmonic_arrow_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def write(
124124
)
125125

126126
result = self._node_property_endpoints.run_job_and_write(
127-
"v2/centrality.harmonic", G, config, write_concurrency, concurrency
127+
"v2/centrality.harmonic", G, config, write_concurrency, concurrency, write_property
128128
)
129129

130130
return ClosenessHarmonicWriteResult(**result)

graphdatascience/procedure_surface/arrow/degree_arrow_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def write(
131131
)
132132

133133
result = self._node_property_endpoints.run_job_and_write(
134-
"v2/centrality.degree", G, config, write_concurrency, concurrency
134+
"v2/centrality.degree", G, config, write_concurrency, concurrency, write_property
135135
)
136136

137137
return DegreeWriteResult(**result)

graphdatascience/procedure_surface/arrow/eigenvector_arrow_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def write(
168168
)
169169

170170
result = self._node_property_endpoints.run_job_and_write(
171-
"v2/centrality.eigenvector", G, config, write_concurrency, concurrency
171+
"v2/centrality.eigenvector", G, config, write_concurrency, concurrency, write_property
172172
)
173173

174174
return EigenvectorWriteResult(**result)

graphdatascience/procedure_surface/arrow/graphsage_predict_arrow_endpoints.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
)
1212

1313
from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
14+
from ...arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
1415
from .model_api_arrow import ModelApiArrow
1516
from .node_property_endpoints import NodePropertyEndpoints
1617

1718

1819
class GraphSagePredictArrowEndpoints(GraphSagePredictEndpoints):
19-
def __init__(self, arrow_client: AuthenticatedArrowClient):
20+
def __init__(self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[RemoteWriteBackClient]):
2021
self._arrow_client = arrow_client
21-
self._node_property_endpoints = NodePropertyEndpoints(arrow_client)
22+
self._node_property_endpoints = NodePropertyEndpoints(arrow_client, write_back_client)
2223
self._model_api = ModelApiArrow(arrow_client)
2324

2425
def stream(
@@ -79,11 +80,7 @@ def write(
7980
)
8081

8182
raw_result = self._node_property_endpoints.run_job_and_write(
82-
"v2/embeddings.graphSage",
83-
G,
84-
config,
85-
write_concurrency,
86-
concurrency,
83+
"v2/embeddings.graphSage", G, config, write_concurrency, concurrency, write_property
8784
)
8885

8986
return GraphSageWriteResult(**raw_result)

graphdatascience/procedure_surface/arrow/graphsage_train_arrow_endpoints.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from graphdatascience.procedure_surface.arrow.graphsage_predict_arrow_endpoints import GraphSagePredictArrowEndpoints
66

77
from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
8+
from ...arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
89
from ..api.graphsage_train_endpoints import (
910
GraphSageTrainEndpoints,
1011
GraphSageTrainResult,
@@ -14,9 +15,10 @@
1415

1516

1617
class GraphSageTrainArrowEndpoints(GraphSageTrainEndpoints):
17-
def __init__(self, arrow_client: AuthenticatedArrowClient):
18+
def __init__(self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[RemoteWriteBackClient]):
1819
self._arrow_client = arrow_client
19-
self._node_property_endpoints = NodePropertyEndpoints(arrow_client)
20+
self._write_back_client = write_back_client
21+
self._node_property_endpoints = NodePropertyEndpoints(arrow_client, write_back_client=write_back_client)
2022
self._model_api = ModelApiArrow(arrow_client)
2123

2224
def train(
@@ -83,7 +85,9 @@ def train(
8385
result = self._node_property_endpoints.run_job_and_get_summary("v2/embeddings.graphSage.train", G, config)
8486

8587
model = GraphSageModelV2(
86-
model_name, self._model_api, predict_endpoints=GraphSagePredictArrowEndpoints(self._arrow_client)
88+
model_name,
89+
self._model_api,
90+
predict_endpoints=GraphSagePredictArrowEndpoints(self._arrow_client, self._write_back_client),
8791
)
8892
train_result = GraphSageTrainResult(**result)
8993

0 commit comments

Comments
 (0)