Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,7 @@
#
# Time in duration format (e.g. 1h30m) after which a backend is considered busy
# LOCALAI_WATCHDOG_BUSY_TIMEOUT=5m

# allowed access ip config, ie: 192.168.1.0/24,10.0.0.1,127.0.0.1
# export LOCALAI_IP_ALLOWLIST="192.168.1.0/24,10.0.0.1,127.0.0.1"
# LOCALAI_IP_ALLOWLIST=192.168.1.0/24
2 changes: 2 additions & 0 deletions core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type RunCMD struct {
ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" help:"Default context size for models" group:"performance"`

Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
IpAllowList string `env:"LOCALAI_IP_ALLOWLIST,IP_ALLOWLIST" help:"A list of IP addresses or CIDR ranges to allow access" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
Expand Down Expand Up @@ -127,6 +128,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
config.WithLoadToMemory(r.LoadToMemory),
config.WithMachineTag(r.MachineTag),
config.WithIPAllowList(r.IpAllowList),
}

if r.DisableMetricsEndpoint {
Expand Down
15 changes: 15 additions & 0 deletions core/config/application_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"regexp"
"time"

"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/LocalAI/pkg/xsysinfo"
"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -63,6 +64,11 @@ type ApplicationConfig struct {
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration

MachineTag string

// ie: 192.168.1.0/24,10.0.0.1,127.0.0.1
IpAllowList string

IPAllowListHelper *utils.IPAllowList
}

type AppOption func(*ApplicationConfig)
Expand Down Expand Up @@ -128,6 +134,15 @@ func WithP2PToken(s string) AppOption {
}
}

func WithIPAllowList(s string) AppOption {
return func(o *ApplicationConfig) {
log.Info().Msgf("Application IpAllowList($LOCALAI_IP_ALLOWLIST): %s", s)
o.IpAllowList = s
var ipAllowListHelper, _ = utils.NewIPAllowList(s)
o.IPAllowListHelper = ipAllowListHelper
}
}

var EnableWatchDog = func(o *ApplicationConfig) {
o.WatchDog = true
}
Expand Down
11 changes: 11 additions & 0 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,17 @@ func API(application *application.Application) (*fiber.App, error) {
router.Use(recover.New())
}

//IP restriction
router.Use(func(c *fiber.Ctx) error {
clientIP := c.IP()
if application.ApplicationConfig().IPAllowListHelper.IsAllowed(clientIP) {
return c.Next()
}
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": "Forbidden: your IP is not allowed",
})
})

if !application.ApplicationConfig().DisableMetrics {
metricsService, err := services.NewLocalAIMetricsService()
if err != nil {
Expand Down
100 changes: 100 additions & 0 deletions core/http/utils/IPAllowList.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package utils
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks better placed as ipallowlist.go file inside core/http rather than having it's own utils package.


import (
"fmt"
"net"
"net/netip"
"strings"
"sync"
)

type IPAllowList struct {
allowList string
cidrs []*net.IPNet
ips []net.IP
mu sync.RWMutex
enabled bool
}

func NewIPAllowList(allowList string) (*IPAllowList, error) {

w := &IPAllowList{}
err := w.Update(allowList)
return w, err
}

func (w *IPAllowList) GetAllowList() string {
return w.allowList
}

func (w *IPAllowList) Update(allowListStr string) error {
var cidrs []*net.IPNet
var ips []net.IP

allowList := make([]string, 0)
if allowListStr != "" {
allowList = strings.Split(allowListStr, ",")
}

for _, item := range allowList {
_, cidrNet, err := net.ParseCIDR(item)
if err == nil {
cidrs = append(cidrs, cidrNet)
} else {
ip := net.ParseIP(item)
if ip != nil {
ips = append(ips, ip)
} else {
return fmt.Errorf("invalid allowList item: %s", item)
}
}
}

w.mu.Lock()
defer w.mu.Unlock()
w.allowList = allowListStr
w.cidrs = cidrs
w.ips = ips
w.enabled = len(cidrs) > 0 || len(ips) > 0
return nil
}

func (w *IPAllowList) IsAllowed(ip interface{}) bool {
if !w.enabled {
return true
}

var parsedIP net.IP
switch v := ip.(type) {
case string:
parsedIP = net.ParseIP(v)
case net.IP:
parsedIP = v
case netip.Addr:
parsedIP = net.IP(v.AsSlice())
default:
if str, ok := v.(string); ok {
parsedIP = net.ParseIP(str)
}
}

if parsedIP == nil {
return false
}

w.mu.RLock()
defer w.mu.RUnlock()

for _, cidr := range w.cidrs {
if cidr.Contains(parsedIP) {
return true
}
}

for _, allowedIP := range w.ips {
if parsedIP.Equal(allowedIP) {
return true
}
}
return false
}
44 changes: 44 additions & 0 deletions core/http/utils/IPAllowList_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package utils

import (
"fmt"
"testing"
)

func TestIPAllowList(t *testing.T) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better be consistent with all the other tests in this repository and use ginkgo

// Test empty AllowList (no restrictions)
w, err := NewIPAllowList("")

if err != nil {
t.Fatalf("Expected no error for empty AllowList, got: %v", err)
}
if !w.IsAllowed("192.168.1.100") {
t.Error("Empty AllowList should allow all IPs")
}

// Test valid AllowList
AllowList := "192.168.1.0/24,10.0.0.1,127.0.0.1"
w, err = NewIPAllowList(AllowList)
if err != nil {
t.Fatalf("Failed to create IP AllowList: %v", err)
}

tests := []struct {
ip string
expected bool
}{
{"192.168.1.100", true},
{"10.0.0.1", true},
{"127.0.0.1", true},
{"10.0.0.2", false},
{"172.16.0.1", false},
}

for _, tc := range tests {
t.Run(fmt.Sprintf("IP: %s", tc.ip), func(t *testing.T) {
if got := w.IsAllowed(tc.ip); got != tc.expected {
t.Errorf("isAllowedIP(%q) = %v, want %v", tc.ip, got, tc.expected)
}
})
}
}
Loading