From 2c01a6c043171f639ccfecb410f69a7d1077616d Mon Sep 17 00:00:00 2001
From: Eran Zimbler <eranz@rumble.me>
Date: Wed, 25 Feb 2015 17:26:26 +0200
Subject: [PATCH] adding _replica_connection option and option

---
 sleepymongoose/handlers.py | 88 +++++++++++++++++++++++++++++---------
 sleepymongoose/httpd.py    | 10 +++--
 2 files changed, 75 insertions(+), 23 deletions(-)

diff --git a/sleepymongoose/handlers.py b/sleepymongoose/handlers.py
index 0e6a9b0..fb8712c 100644
--- a/sleepymongoose/handlers.py
+++ b/sleepymongoose/handlers.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from bson.son import SON
-from pymongo import Connection, ASCENDING, DESCENDING
+from pymongo import Connection, ASCENDING, DESCENDING, ReplicaSetConnection
 from pymongo.errors import ConnectionFailure, ConfigurationError, OperationFailure, AutoReconnect
 from bson import json_util
 
@@ -28,33 +28,41 @@ class MongoHandler:
 
     _cursor_id = 0
 
-    def __init__(self, mongos):
+    def __init__(self, mongos, replica=None):
         self.connections = {}
+        if replica:
+                args = MongoFakeFieldStorage({"server" : ",".join(mongos), "replica": replica })
+                out = MongoFakeStream()
+                self._connect(args, out.ostream)
+        else:
+            for host in mongos:
+                args = MongoFakeFieldStorage({"server" : host })
 
-        for host in mongos:
-            args = MongoFakeFieldStorage({"server" : host})
-
-            out = MongoFakeStream()
-            if len(mongos) == 1:
-                name = "default"
-            else:
-                name = host.replace(".", "") 
-                name = name.replace(":", "")
-
-            self._connect(args, out.ostream, name = name)
+                out = MongoFakeStream()
+                if len(mongos) == 1:
+                    name = "default"
+                else:
+                    name = host.replace(".", "") 
+                    name = name.replace(":", "")
+                self._connect(args, out.ostream, name = name)
         
-    def _get_connection(self, name = None, uri='mongodb://localhost:27017'):
+    def _get_connection(self, name = None, uri='mongodb://localhost:27017', replica=None):
         if name == None:
             name = "default"
 
         if name in self.connections:
             return self.connections[name]
-        
-        try:
-            connection = Connection(uri, network_timeout = 2)
-        except (ConnectionFailure, ConfigurationError):
-            return None
-
+        if not replica:
+            try:
+                connection = Connection(uri, network_timeout = 2)
+            except (ConnectionFailure, ConfigurationError):
+                return None
+        else:
+            try:
+                connection = ReplicaSetConnection(uri, replicaSet=replica, network_timeout = 2)
+                print connection
+            except (ConnectionFailure, ConfigurationError):
+                return None
         self.connections[name] = connection
         return connection
 
@@ -145,6 +153,46 @@ def _status(self, args, out, name = None, db = None, collection = None):
 
         out(json.dumps(result))
     
+    def _replica_connection(self, args, out, name = None, db = None, collection = None):
+        """
+        connect to a replica set
+        """
+
+        if type(args).__name__ == 'dict':
+            out('{"ok" : 0, "errmsg" : "_connect must be a POST request"}')
+            return
+        if "server" in args:
+            try:
+                uri = args.getvalue('server')
+            except Exception, e:
+                print uri
+                print e
+                out('{"ok" : 0, "errmsg" : "invalid server uri given", "server" : "%s"}' % uri)
+                return
+        else:
+            uri = 'mongodb://localhost:27017'
+
+        if name == None:
+            name = "default"
+
+        if "replica" in args:
+            try:
+                replica = args.getvalue('replica')
+            except Exception, e:
+                print replica
+                print e
+                out('{"ok" : 0, "errmsg" : "missing replica id", "replica" : "%s"}' % replica)
+                return
+        else:
+            out('{"ok" : 0, "errmsg" : "missing replica id"}')
+            return
+        conn = self._get_connection(name, uri, replica=replica)
+        if conn != None:
+            out('{"ok" : 1, "server" : "%s", "name" : "%s"}' % (uri, name))
+        else:
+            out('{"ok" : 0, "errmsg" : "could not connect", "server" : "%s", "name" : "%s"}' % (uri, name))
+
+
     def _connect(self, args, out, name = None, db = None, collection = None):
         """
         connect to a mongod
diff --git a/sleepymongoose/httpd.py b/sleepymongoose/httpd.py
index 53629d5..a937320 100644
--- a/sleepymongoose/httpd.py
+++ b/sleepymongoose/httpd.py
@@ -72,6 +72,7 @@ class MongoHTTPRequest(BaseHTTPRequestHandler):
 
     docroot = "."
     mongos = []
+    replica = None
     response_headers = []
     jsonp_callback = None;
 
@@ -235,7 +236,7 @@ def serve_forever(port):
             print "--------Secure Connection--------\n"
             server = MongoServer(('', port), MongoHTTPSRequest)
 
-        MongoHandler.mh = MongoHandler(MongoHTTPRequest.mongos)
+        MongoHandler.mh = MongoHandler(MongoHTTPRequest.mongos,replica= MongoHTTPRequest.replica)
         
         print "listening for connections on http://localhost:27080\n"
         try:
@@ -259,12 +260,13 @@ def usage():
     print "\t-d|--docroot\tlocation from which to load files"
     print "\t-s|--secure\tlocation of .pem file if ssl is desired"
     print "\t-m|--mongos\tcomma-separated list of mongo servers to connect to"
+    print "\t-r|--replica\treplica set name"
 
 
 def main():
     try:
-        opts, args = getopt.getopt(sys.argv[1:], "xd:s:m:", ["xorigin", "docroot=",
-            "secure=", "mongos="])
+        opts, args = getopt.getopt(sys.argv[1:], "xd:s:m:r:", ["xorigin", "docroot=",
+            "secure=", "mongos=", "replica="])
 
         for o, a in opts:
             if o == "-d" or o == "--docroot":
@@ -273,6 +275,8 @@ def main():
                 MongoHTTPRequest.docroot = a
             if o == "-s" or o == "--secure":
                 MongoServer.pem = a
+            if o == "-r" or o == "--replica":
+                MongoHTTPRequest.replica = a
             if o == "-m" or o == "--mongos":
                 MongoHTTPRequest.mongos = a.split(',')
             if o == "-x" or o == "--xorigin":