diff --git a/.env b/.env index 53d796bc1edb..be710e0ee3ad 100644 --- a/.env +++ b/.env @@ -100,3 +100,7 @@ # # Time in duration format (e.g. 1h30m) after which a backend is considered busy # LOCALAI_WATCHDOG_BUSY_TIMEOUT=5m + +# allowed access ip config, ie: 192.168.1.0/24,10.0.0.1,127.0.0.1 +# export LOCALAI_IP_ALLOWLIST="192.168.1.0/24,10.0.0.1,127.0.0.1" +# LOCALAI_IP_ALLOWLIST=192.168.1.0/24 diff --git a/core/cli/run.go b/core/cli/run.go index 999e05d29bdc..9e978da5da91 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -46,6 +46,7 @@ type RunCMD struct { ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" help:"Default context size for models" group:"performance"` Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` + IpAllowList string `env:"LOCALAI_IP_ALLOWLIST,IP_ALLOWLIST" help:"A list of IP addresses or CIDR ranges to allow access" group:"api"` CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"` CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"` CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"` @@ -127,6 +128,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { config.WithP2PNetworkID(r.Peer2PeerNetworkID), config.WithLoadToMemory(r.LoadToMemory), config.WithMachineTag(r.MachineTag), + config.WithIPAllowList(r.IpAllowList), } if r.DisableMetricsEndpoint { diff --git a/core/config/application_config.go b/core/config/application_config.go index 775e30f66034..d2d6c5ffa060 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -6,6 +6,7 @@ import ( "regexp" "time" + "github.com/mudler/LocalAI/core/http/utils" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/rs/zerolog/log" @@ -63,6 +64,11 @@ type ApplicationConfig struct { WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration MachineTag string + + // ie: 192.168.1.0/24,10.0.0.1,127.0.0.1 + IpAllowList string + + IPAllowListHelper *utils.IPAllowList } type AppOption func(*ApplicationConfig) @@ -128,6 +134,15 @@ func WithP2PToken(s string) AppOption { } } +func WithIPAllowList(s string) AppOption { + return func(o *ApplicationConfig) { + log.Info().Msgf("Application IpAllowList($LOCALAI_IP_ALLOWLIST): %s", s) + o.IpAllowList = s + var ipAllowListHelper, _ = utils.NewIPAllowList(s) + o.IPAllowListHelper = ipAllowListHelper + } +} + var EnableWatchDog = func(o *ApplicationConfig) { o.WatchDog = true } diff --git a/core/http/app.go b/core/http/app.go index 09f06883451a..a3c365f87487 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -128,6 +128,17 @@ func API(application *application.Application) (*fiber.App, error) { router.Use(recover.New()) } + //IP restriction + router.Use(func(c *fiber.Ctx) error { + clientIP := c.IP() + if application.ApplicationConfig().IPAllowListHelper.IsAllowed(clientIP) { + return c.Next() + } + return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ + "error": "Forbidden: your IP is not allowed", + }) + }) + if !application.ApplicationConfig().DisableMetrics { metricsService, err := services.NewLocalAIMetricsService() if err != nil { diff --git a/core/http/utils/IPAllowList.go b/core/http/utils/IPAllowList.go new file mode 100644 index 000000000000..0606b237f520 --- /dev/null +++ b/core/http/utils/IPAllowList.go @@ -0,0 +1,100 @@ +package utils + +import ( + "fmt" + "net" + "net/netip" + "strings" + "sync" +) + +type IPAllowList struct { + allowList string + cidrs []*net.IPNet + ips []net.IP + mu sync.RWMutex + enabled bool +} + +func NewIPAllowList(allowList string) (*IPAllowList, error) { + + w := &IPAllowList{} + err := w.Update(allowList) + return w, err +} + +func (w *IPAllowList) GetAllowList() string { + return w.allowList +} + +func (w *IPAllowList) Update(allowListStr string) error { + var cidrs []*net.IPNet + var ips []net.IP + + allowList := make([]string, 0) + if allowListStr != "" { + allowList = strings.Split(allowListStr, ",") + } + + for _, item := range allowList { + _, cidrNet, err := net.ParseCIDR(item) + if err == nil { + cidrs = append(cidrs, cidrNet) + } else { + ip := net.ParseIP(item) + if ip != nil { + ips = append(ips, ip) + } else { + return fmt.Errorf("invalid allowList item: %s", item) + } + } + } + + w.mu.Lock() + defer w.mu.Unlock() + w.allowList = allowListStr + w.cidrs = cidrs + w.ips = ips + w.enabled = len(cidrs) > 0 || len(ips) > 0 + return nil +} + +func (w *IPAllowList) IsAllowed(ip interface{}) bool { + if !w.enabled { + return true + } + + var parsedIP net.IP + switch v := ip.(type) { + case string: + parsedIP = net.ParseIP(v) + case net.IP: + parsedIP = v + case netip.Addr: + parsedIP = net.IP(v.AsSlice()) + default: + if str, ok := v.(string); ok { + parsedIP = net.ParseIP(str) + } + } + + if parsedIP == nil { + return false + } + + w.mu.RLock() + defer w.mu.RUnlock() + + for _, cidr := range w.cidrs { + if cidr.Contains(parsedIP) { + return true + } + } + + for _, allowedIP := range w.ips { + if parsedIP.Equal(allowedIP) { + return true + } + } + return false +} diff --git a/core/http/utils/IPAllowList_test.go b/core/http/utils/IPAllowList_test.go new file mode 100644 index 000000000000..24cff127f4ba --- /dev/null +++ b/core/http/utils/IPAllowList_test.go @@ -0,0 +1,44 @@ +package utils + +import ( + "fmt" + "testing" +) + +func TestIPAllowList(t *testing.T) { + // Test empty AllowList (no restrictions) + w, err := NewIPAllowList("") + + if err != nil { + t.Fatalf("Expected no error for empty AllowList, got: %v", err) + } + if !w.IsAllowed("192.168.1.100") { + t.Error("Empty AllowList should allow all IPs") + } + + // Test valid AllowList + AllowList := "192.168.1.0/24,10.0.0.1,127.0.0.1" + w, err = NewIPAllowList(AllowList) + if err != nil { + t.Fatalf("Failed to create IP AllowList: %v", err) + } + + tests := []struct { + ip string + expected bool + }{ + {"192.168.1.100", true}, + {"10.0.0.1", true}, + {"127.0.0.1", true}, + {"10.0.0.2", false}, + {"172.16.0.1", false}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf("IP: %s", tc.ip), func(t *testing.T) { + if got := w.IsAllowed(tc.ip); got != tc.expected { + t.Errorf("isAllowedIP(%q) = %v, want %v", tc.ip, got, tc.expected) + } + }) + } +}