diff --git a/pkg/lib/aws/servicequotas.go b/pkg/lib/aws/servicequotas.go index 534058a59c..480e92db32 100644 --- a/pkg/lib/aws/servicequotas.go +++ b/pkg/lib/aws/servicequotas.go @@ -30,13 +30,6 @@ var _standardInstanceFamilies = strset.New("a", "c", "d", "h", "i", "m", "r", "t var _knownInstanceFamilies = strset.Union(_standardInstanceFamilies, strset.New("p", "g", "inf", "x", "f", "mac")) const ( - _elasticIPsQuotaCode = "L-0263D0A3" - _internetGatewayQuotaCode = "L-A4707A72" - _natGatewayQuotaCode = "L-FE5A380F" - _vpcQuotaCode = "L-F678F1CE" - _securityGroupsQuotaCode = "L-E79EC296" - _securityGroupRulesQuotaCode = "L-0EA8095F" - // 11 inbound rules _baseInboundRulesForNodeGroup = 11 _inboundRulesPerAZ = 8 @@ -159,181 +152,167 @@ func (c *Client) VerifyInstanceQuota(instances []InstanceTypeRequests) error { return nil } -func (c *Client) VerifyNetworkQuotas( - requiredInternetGateways int, - natGatewayRequired bool, - highlyAvailableNATGateway bool, - requiredVPCs int, - availabilityZones strset.Set, - numNodeGroups int, - longestCIDRWhiteList int) error { - quotaCodeToValueMap := map[string]int{ - _elasticIPsQuotaCode: 0, // elastic IP quota code - _internetGatewayQuotaCode: 0, // internet gw quota code - _natGatewayQuotaCode: 0, // nat gw quota code - _vpcQuotaCode: 0, // vpc quota code - _securityGroupsQuotaCode: 0, // security groups quota code - _securityGroupRulesQuotaCode: 0, // security group rules quota code - } - - err := c.ServiceQuotas().ListServiceQuotasPages( - &servicequotas.ListServiceQuotasInput{ - ServiceCode: aws.String("ec2"), - }, - func(page *servicequotas.ListServiceQuotasOutput, lastPage bool) bool { - if page == nil { - return false - } - for _, quota := range page.Quotas { - if quota == nil || quota.QuotaCode == nil || quota.Value == nil { - continue - } - if _, ok := quotaCodeToValueMap[*quota.QuotaCode]; ok { - quotaCodeToValueMap[*quota.QuotaCode] = int(*quota.Value) +func (c *Client) ListServiceQuotas(quotaCodes []string, serviceCodes []string) (map[string]int, error) { + desiredQuotaCodes := strset.New(quotaCodes...) + quotaCodeToValueMap := map[string]int{} + + for _, serviceCode := range serviceCodes { + err := c.ServiceQuotas().ListServiceQuotasPages( + &servicequotas.ListServiceQuotasInput{ + ServiceCode: aws.String(serviceCode), + }, + func(page *servicequotas.ListServiceQuotasOutput, lastPage bool) bool { + if page == nil { return false } - } - return true - }, - ) - if err != nil { - return errors.WithStack(err) + for _, quota := range page.Quotas { + if quota == nil || quota.QuotaCode == nil || quota.Value == nil { + continue + } + if desiredQuotaCodes.Has(*quota.QuotaCode) { + quotaCodeToValueMap[*quota.QuotaCode] = int(*quota.Value) + } + } + return true + }, + ) + if err != nil { + return nil, errors.Wrap(err, serviceCode) + } } - err = c.ServiceQuotas().ListServiceQuotasPages( - &servicequotas.ListServiceQuotasInput{ - ServiceCode: aws.String("vpc"), - }, - func(page *servicequotas.ListServiceQuotasOutput, lastPage bool) bool { - if page == nil { - return false - } - for _, quota := range page.Quotas { - if quota == nil || quota.QuotaCode == nil || quota.Value == nil { - continue - } - if _, ok := quotaCodeToValueMap[*quota.QuotaCode]; ok { - quotaCodeToValueMap[*quota.QuotaCode] = int(*quota.Value) - } - } - return true - }, - ) + return quotaCodeToValueMap, nil +} + +func (c *Client) VerifyInternetGatewayQuota(internetGatewayQuota int, requiredInternetGateways int) error { + internetGatewaysInUse, err := c.ListInternetGateways() if err != nil { - return errors.WithStack(err) + return err } - // check internet GW quota - if requiredInternetGateways > 0 { - internetGatewaysInUse, err := c.ListInternetGateways() - if err != nil { - return err - } - if quotaCodeToValueMap[_internetGatewayQuotaCode]-len(internetGatewaysInUse)-requiredInternetGateways < 0 { - additionalQuotaRequired := len(internetGatewaysInUse) + requiredInternetGateways - quotaCodeToValueMap[_internetGatewayQuotaCode] - return ErrorInternetGatewayLimitExceeded(quotaCodeToValueMap[_internetGatewayQuotaCode], additionalQuotaRequired, c.Region) - } + additionalQuotaRequired := len(internetGatewaysInUse) + requiredInternetGateways - internetGatewayQuota + + if additionalQuotaRequired > 0 { + return ErrorInternetGatewayLimitExceeded(internetGatewayQuota, additionalQuotaRequired, c.Region) } + return nil +} - if natGatewayRequired { - // get NAT GW in use per selected AZ - natGateways, err := c.DescribeNATGateways() - if err != nil { - return err - } - subnets, err := c.DescribeSubnets() - if err != nil { - return err +func (c *Client) VerifyNATGatewayQuota(natGatewayQuota int, availabilityZones strset.Set, highlyAvailableNATGateway bool) error { + // get NAT GW in use per selected AZ + natGateways, err := c.DescribeNATGateways() + if err != nil { + return err + } + subnets, err := c.DescribeSubnets() + if err != nil { + return err + } + azToGatewaysInUse := map[string]int{} + for _, natGateway := range natGateways { + if natGateway.SubnetId == nil { + continue } - azToGatewaysInUse := map[string]int{} - for _, natGateway := range natGateways { - if natGateway.SubnetId == nil { + for _, subnet := range subnets { + if subnet.SubnetId == nil || subnet.AvailabilityZone == nil { continue } - for _, subnet := range subnets { - if subnet.SubnetId == nil || subnet.AvailabilityZone == nil { - continue - } - if !availabilityZones.Has(*subnet.AvailabilityZone) { - continue - } - if *subnet.SubnetId == *natGateway.SubnetId { - azToGatewaysInUse[*subnet.AvailabilityZone]++ - } + if !availabilityZones.Has(*subnet.AvailabilityZone) { + continue } - } - // check NAT GW quota - numOfExhaustedNATGatewayAZs := 0 - azsWithQuotaDeficit := []string{} - for az, numActiveGatewaysOnAZ := range azToGatewaysInUse { - // -1 comes from the NAT gateway we require per AZ - azDeficit := quotaCodeToValueMap[_natGatewayQuotaCode] - numActiveGatewaysOnAZ - 1 - if azDeficit < 0 { - numOfExhaustedNATGatewayAZs++ - azsWithQuotaDeficit = append(azsWithQuotaDeficit, az) + if *subnet.SubnetId == *natGateway.SubnetId { + azToGatewaysInUse[*subnet.AvailabilityZone]++ } } - if (highlyAvailableNATGateway && numOfExhaustedNATGatewayAZs > 0) || (!highlyAvailableNATGateway && numOfExhaustedNATGatewayAZs == len(availabilityZones)) { - return ErrorNATGatewayLimitExceeded(quotaCodeToValueMap[_natGatewayQuotaCode], 1, azsWithQuotaDeficit, c.Region) + } + // check NAT GW quota + numOfExhaustedNATGatewayAZs := 0 + azsWithQuotaDeficit := []string{} + for az, numActiveGatewaysOnAZ := range azToGatewaysInUse { + // -1 comes from the NAT gateway we require per AZ + azDeficit := natGatewayQuota - numActiveGatewaysOnAZ - 1 + if azDeficit < 0 { + numOfExhaustedNATGatewayAZs++ + azsWithQuotaDeficit = append(azsWithQuotaDeficit, az) } } + if (highlyAvailableNATGateway && numOfExhaustedNATGatewayAZs > 0) || (!highlyAvailableNATGateway && numOfExhaustedNATGatewayAZs == len(availabilityZones)) { + return ErrorNATGatewayLimitExceeded(natGatewayQuota, 1, azsWithQuotaDeficit, c.Region) + } - // check EIP quota - if natGatewayRequired { - elasticIPsInUse, err := c.ListElasticIPs() - if err != nil { - return err - } - var requiredElasticIPs int - if highlyAvailableNATGateway { - requiredElasticIPs = len(availabilityZones) - } else { - requiredElasticIPs = 1 - } - if quotaCodeToValueMap[_elasticIPsQuotaCode]-len(elasticIPsInUse)-requiredElasticIPs < 0 { - additionalQuotaRequired := len(elasticIPsInUse) + requiredElasticIPs - quotaCodeToValueMap[_elasticIPsQuotaCode] - return ErrorEIPLimitExceeded(quotaCodeToValueMap[_elasticIPsQuotaCode], additionalQuotaRequired, c.Region) - } + return nil +} + +func (c *Client) VerifyEIPQuota(eipQuota int, availabilityZones strset.Set, highlyAvailableNATGateway bool) error { + elasticIPsInUse, err := c.ListElasticIPs() + if err != nil { + return err + } + var requiredElasticIPs int + if highlyAvailableNATGateway { + requiredElasticIPs = len(availabilityZones) + } else { + requiredElasticIPs = 1 } - // check VPC quota - if requiredVPCs > 0 { - vpcs, err := c.DescribeVpcs() - if err != nil { - return err - } - if quotaCodeToValueMap[_vpcQuotaCode]-len(vpcs)-requiredVPCs < 0 { - additionalQuotaRequired := len(vpcs) + requiredVPCs - quotaCodeToValueMap[_vpcQuotaCode] - return ErrorVPCLimitExceeded(quotaCodeToValueMap[_vpcQuotaCode], additionalQuotaRequired, c.Region) - } + additionalQuotaRequired := len(elasticIPsInUse) + requiredElasticIPs - eipQuota + + if additionalQuotaRequired > 0 { + return ErrorEIPLimitExceeded(eipQuota, additionalQuotaRequired, c.Region) } - // check rules quota for nodegroup SGs - requiredRulesForSG := requiredRulesForNodeGroupSecurityGroup(len(availabilityZones), longestCIDRWhiteList) - if requiredRulesForSG > quotaCodeToValueMap[_securityGroupRulesQuotaCode] { - additionalQuotaRequired := requiredRulesForSG - quotaCodeToValueMap[_securityGroupRulesQuotaCode] - return ErrorSecurityGroupRulesExceeded(quotaCodeToValueMap[_securityGroupRulesQuotaCode], additionalQuotaRequired, c.Region) + return nil +} + +func (c *Client) VerifyVPCQuota(vpcQuota int, requiredVPCs int) error { + vpcs, err := c.DescribeVpcs() + if err != nil { + return err } - // check rules quota for control plane SG - requiredRulesForCPSG := requiredRulesForControlPlaneSecurityGroup(numNodeGroups) - if requiredRulesForCPSG > quotaCodeToValueMap[_securityGroupRulesQuotaCode] { - additionalQuotaRequired := requiredRulesForCPSG - quotaCodeToValueMap[_securityGroupRulesQuotaCode] - return ErrorSecurityGroupRulesExceeded(quotaCodeToValueMap[_securityGroupRulesQuotaCode], additionalQuotaRequired, c.Region) + additionalQuotaRequired := len(vpcs) + requiredVPCs - vpcQuota + + if additionalQuotaRequired > 0 { + return ErrorVPCLimitExceeded(vpcQuota, additionalQuotaRequired, c.Region) } + return nil +} - // check security groups quota +func (c *Client) VerifySecurityGroupQuota(securifyGroupsQuota int, numNodeGroups int) error { requiredSecurityGroups := requiredSecurityGroups(numNodeGroups) sgs, err := c.DescribeSecurityGroups() if err != nil { return err } - if quotaCodeToValueMap[_securityGroupsQuotaCode]-len(sgs)-requiredSecurityGroups < 0 { - additionalQuotaRequired := len(sgs) + requiredSecurityGroups - quotaCodeToValueMap[_securityGroupsQuotaCode] - return ErrorSecurityGroupLimitExceeded(quotaCodeToValueMap[_securityGroupsQuotaCode], additionalQuotaRequired, c.Region) + additionalQuotaRequired := len(sgs) + requiredSecurityGroups - securifyGroupsQuota + + if additionalQuotaRequired > 0 { + return ErrorSecurityGroupLimitExceeded(securifyGroupsQuota, additionalQuotaRequired, c.Region) + + } + return nil +} + +func (c *Client) VerifySecurityGroupRulesQuota( + securifyGroupRulesQuota int, + availabilityZones strset.Set, + numNodeGroups int, + longestCIDRWhiteList int) error { + + // check rules quota for nodegroup SGs + requiredRulesForSG := requiredRulesForNodeGroupSecurityGroup(len(availabilityZones), longestCIDRWhiteList) + if requiredRulesForSG > securifyGroupRulesQuota { + additionalQuotaRequired := requiredRulesForSG - securifyGroupRulesQuota + return ErrorSecurityGroupRulesExceeded(securifyGroupRulesQuota, additionalQuotaRequired, c.Region) } + // check rules quota for control plane SG + requiredRulesForCPSG := requiredRulesForControlPlaneSecurityGroup(numNodeGroups) + if requiredRulesForCPSG > securifyGroupRulesQuota { + additionalQuotaRequired := requiredRulesForCPSG - securifyGroupRulesQuota + return ErrorSecurityGroupRulesExceeded(securifyGroupRulesQuota, additionalQuotaRequired, c.Region) + } return nil } diff --git a/pkg/types/clusterconfig/cluster_config.go b/pkg/types/clusterconfig/cluster_config.go index 37b3a7600b..d82d2fa627 100644 --- a/pkg/types/clusterconfig/cluster_config.go +++ b/pkg/types/clusterconfig/cluster_config.go @@ -1000,7 +1000,7 @@ func (cc *Config) Validate(awsClient *aws.Client) error { requiredVPCs = 1 } longestCIDRWhiteList := libmath.MaxInt(len(cc.APILoadBalancerCIDRWhiteList), len(cc.OperatorLoadBalancerCIDRWhiteList)) - if err := awsClient.VerifyNetworkQuotas(1, cc.NATGateway != NoneNATGateway, cc.NATGateway == HighlyAvailableNATGateway, requiredVPCs, strset.FromSlice(cc.AvailabilityZones), len(cc.NodeGroups), longestCIDRWhiteList); err != nil { + if err := VerifyNetworkQuotas(awsClient, 1, cc.NATGateway != NoneNATGateway, cc.NATGateway == HighlyAvailableNATGateway, requiredVPCs, strset.FromSlice(cc.AvailabilityZones), len(cc.NodeGroups), longestCIDRWhiteList); err != nil { // Skip AWS errors, since some regions (e.g. eu-north-1) do not support this API if !aws.IsAWSError(err) { return err diff --git a/pkg/types/clusterconfig/network_validations.go b/pkg/types/clusterconfig/network_validations.go new file mode 100644 index 0000000000..70ed4e3f04 --- /dev/null +++ b/pkg/types/clusterconfig/network_validations.go @@ -0,0 +1,134 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package clusterconfig + +import ( + "fmt" + "strings" + + "github.com/cortexlabs/cortex/pkg/lib/aws" + "github.com/cortexlabs/cortex/pkg/lib/sets/strset" +) + +const ( + _elasticIPsQuotaCode = "L-0263D0A3" + _internetGatewayQuotaCode = "L-A4707A72" + _natGatewayQuotaCode = "L-FE5A380F" + _vpcQuotaCode = "L-F678F1CE" + _securityGroupsQuotaCode = "L-E79EC296" + _securityGroupRulesQuotaCode = "L-0EA8095F" +) + +func VerifyNetworkQuotas( + awsClient *aws.Client, + requiredInternetGateways int, + natGatewayRequired bool, + highlyAvailableNATGateway bool, + requiredVPCs int, + availabilityZones strset.Set, + numNodeGroups int, + longestCIDRWhiteList int) error { + + desiredQuotaCodes := []string{ + _elasticIPsQuotaCode, + _internetGatewayQuotaCode, + _natGatewayQuotaCode, + _vpcQuotaCode, + _securityGroupsQuotaCode, + _securityGroupRulesQuotaCode, + } + + serviceCodes := []string{"ec2", "vpc"} + + quotaCodeToValueMap, err := awsClient.ListServiceQuotas(desiredQuotaCodes, serviceCodes) + if err != nil { + return err + } + + var skippedValidations []string + defer func() { + if len(skippedValidations) > 0 { + fmt.Println(strings.Join(skippedValidations, "\n")) + } + }() + + // check internet GW quota + if requiredInternetGateways > 0 { + if internetGatewayQuota, found := quotaCodeToValueMap[_internetGatewayQuotaCode]; found { + err := awsClient.VerifyInternetGatewayQuota(internetGatewayQuota, requiredInternetGateways) + if err != nil { + return err + } + } else { + skippedValidations = append(skippedValidations, fmt.Sprintf("skipping internet gateway quota verification: unable to find internet gateway quota (%s)", _internetGatewayQuotaCode)) + } + } + + if natGatewayRequired { + if natGatewayQuota, found := quotaCodeToValueMap[_natGatewayQuotaCode]; found { + err := awsClient.VerifyNATGatewayQuota(natGatewayQuota, availabilityZones, highlyAvailableNATGateway) + if err != nil { + return err + } + } else { + skippedValidations = append(skippedValidations, fmt.Sprintf("skipping nat gateway quota verification: unable to find nat gateway quota (%s)\n", _natGatewayQuotaCode)) + } + } + + // check EIP quota + if natGatewayRequired { + if eipQuota, found := quotaCodeToValueMap[_elasticIPsQuotaCode]; found { + err := awsClient.VerifyEIPQuota(eipQuota, availabilityZones, highlyAvailableNATGateway) + if err != nil { + return err + } + } else { + skippedValidations = append(skippedValidations, fmt.Sprintf("skipping elastic ip quota verification: unable to find elastic ip quota (%s)\n", _elasticIPsQuotaCode)) + } + } + + if requiredVPCs > 0 { + if vpcQuota, found := quotaCodeToValueMap[_vpcQuotaCode]; found { + err := awsClient.VerifyVPCQuota(vpcQuota, requiredVPCs) + if err != nil { + return err + } + } else { + skippedValidations = append(skippedValidations, fmt.Sprintf("skipping vpc quota verification: unable to find vpc quota (%s)\n", _vpcQuotaCode)) + } + } + + if securityGroupRulesQuota, found := quotaCodeToValueMap[_securityGroupRulesQuotaCode]; found { + err := awsClient.VerifySecurityGroupRulesQuota(securityGroupRulesQuota, availabilityZones, numNodeGroups, longestCIDRWhiteList) + if err != nil { + return err + } + } else { + skippedValidations = append(skippedValidations, fmt.Sprintf("skipping security group rules quota verification: unable to find security group rules quota (%s)\n", _securityGroupRulesQuotaCode)) + } + + if securityGroupsQuota, found := quotaCodeToValueMap[_securityGroupsQuotaCode]; found { + err := awsClient.VerifySecurityGroupQuota(securityGroupsQuota, numNodeGroups) + if err != nil { + return err + } + } else { + skippedValidations = append(skippedValidations, fmt.Sprintf("skipping security group quota verification: unable to find security group quota (%s)\n", _securityGroupsQuotaCode)) + } + + return nil +}