From d9d7a714cdf53b7521013a80e4239e5f92d6a63f Mon Sep 17 00:00:00 2001
From: Jack O'Sullivan <jackos1998@gmail.com>
Date: Mon, 11 Dec 2023 14:59:40 +0000
Subject: [PATCH] nixos/firewall: Add ability to forward per external IP

---
 nixos/boxes/colony/vms/estuary/default.nix |  3 +-
 nixos/modules/firewall.nix                 | 87 +++++++++++++++-------
 2 files changed, 61 insertions(+), 29 deletions(-)

diff --git a/nixos/boxes/colony/vms/estuary/default.nix b/nixos/boxes/colony/vms/estuary/default.nix
index bc1fa07..b44a815 100644
--- a/nixos/boxes/colony/vms/estuary/default.nix
+++ b/nixos/boxes/colony/vms/estuary/default.nix
@@ -356,8 +356,7 @@ in
                 nat = {
                   enable = true;
                   externalInterface = "wan";
-                  externalIP = assignments.internal.ipv4.address;
-                  forwardPorts = [
+                  forwardPorts."${assignments.internal.ipv4.address}" = [
                     {
                       port = "http";
                       dst = allAssignments.middleman.internal.ipv4.address;
diff --git a/nixos/modules/firewall.nix b/nixos/modules/firewall.nix
index 30ee703..961e6dd 100644
--- a/nixos/modules/firewall.nix
+++ b/nixos/modules/firewall.nix
@@ -1,6 +1,9 @@
 { lib, options, config, ... }:
 let
-  inherit (lib) optionalString concatStringsSep concatMapStringsSep optionalAttrs mkIf mkDefault mkMerge mkOverride;
+  inherit (builtins) typeOf replaceStrings attrNames;
+  inherit (lib)
+    optionalString concatStringsSep concatMapStringsSep mapAttrsToList optionalAttrs mkIf
+    mkDefault mkMerge mkOverride;
   inherit (lib.my) isIPv6 mkOpt' mkBoolOpt';
 
   allowICMP = ''
@@ -63,8 +66,8 @@ in
 
     nat = with options.networking.nat; {
       enable = mkBoolOpt' true "Whether to enable IP forwarding and NAT.";
-      inherit externalInterface externalIP;
-      forwardPorts = mkOpt' (listOf (submodule forwardOpts)) [ ] "List of port forwards.";
+      inherit externalInterface;
+      forwardPorts = mkOpt' (either (listOf (submodule forwardOpts)) (attrsOf (listOf (submodule forwardOpts)))) [ ] "IPv4 port forwards";
     };
   };
 
@@ -144,11 +147,16 @@ in
         };
       };
     }
-    (mkIf cfg.nat.enable {
+    (mkIf cfg.nat.enable (
+    let
+      iifForward = typeOf cfg.nat.forwardPorts == "list" && cfg.nat.forwardPorts != [ ];
+      dipForward = typeOf cfg.nat.forwardPorts == "set" && cfg.nat.forwardPorts != { };
+    in
+    {
       assertions = [
         {
-          assertion = with cfg.nat; (forwardPorts != [ ]) -> (externalInterface != null);
-          message = "my.firewall.nat.forwardPorts requires my.firewall.nat.external{Interface,IP}";
+          assertion = with cfg.nat; iifForward -> (externalInterface != null);
+          message = "my.firewall.nat.forwardPorts as list requires my.firewall.nat.externalInterface";
         }
       ];
 
@@ -171,43 +179,68 @@ in
 
       my.firewall.extraRules =
         let
+          ipK = ip: "ip${optionalString (isIPv6 ip) "6"}";
+          ipEscaped = replaceStrings ["." ":"] ["-" "-"];
+
           makeFilter = f:
-          let
-            v6 = isIPv6 f.dst;
-          in
-            "ip${optionalString v6 "6"} daddr ${f.dst} ${f.proto} dport ${toString f.dstPort} accept";
+            "${ipK f.dst} daddr ${f.dst} ${f.proto} dport ${toString f.dstPort} accept";
           makeForward = f:
-            let
-              v6 = isIPv6 f.dst;
-            in
-              "${f.proto} dport ${toString f.port} dnat ip${optionalString v6 "6"} to ${f.dst}:${toString f.dstPort}";
+            "${f.proto} dport ${toString f.port} dnat ${ipK f.dst} to ${f.dst}:${toString f.dstPort}";
         in
         ''
           table inet filter {
-            chain filter-port-forwards {
-              ${concatMapStringsSep "\n    " makeFilter cfg.nat.forwardPorts}
-              return
-            }
+            ${optionalString iifForward ''
+              chain filter-iif-port-forwards {
+                ${concatMapStringsSep "\n    " makeFilter cfg.nat.forwardPorts}
+                return
+              }
+            ''}
+            ${optionalString
+              dipForward
+              (concatStringsSep "\n" (mapAttrsToList (ip: fs: ''
+                chain filter-fwd-${ipEscaped ip} {
+                  ${concatMapStringsSep "\n    " makeFilter fs}
+                  return
+                }
+              '') cfg.nat.forwardPorts))}
+
             chain forward {
               ${optionalString
-                (cfg.nat.externalInterface != null)
-                "iifname ${cfg.nat.externalInterface} jump filter-port-forwards"}
+                iifForward
+                "iifname ${cfg.nat.externalInterface} jump filter-iif-port-forwards"}
+              ${optionalString
+                dipForward
+                (concatMapStringsSep "\n    " (ip: "${ipK ip} daddr ${ip} jump filter-fwd-${ipEscaped ip}") (attrNames cfg.nat.forwardPorts))}
             }
           }
 
           table inet nat {
-            chain port-forward {
-              ${concatMapStringsSep "\n    " makeForward cfg.nat.forwardPorts}
-              return
-            }
+            ${optionalString iifForward ''
+              chain iif-port-forward {
+                ${concatMapStringsSep "\n    " makeForward cfg.nat.forwardPorts}
+                return
+              }
+            ''}
+            ${optionalString
+              dipForward
+              (concatStringsSep "\n" (mapAttrsToList (ip: fs: ''
+                chain fwd-${ipEscaped ip} {
+                  ${concatMapStringsSep "\n    " makeForward fs}
+                  return
+                }
+              '') cfg.nat.forwardPorts))}
+
             chain prerouting {
               ${optionalString
-                (cfg.nat.externalInterface != null)
-                "${if (cfg.nat.externalIP != null) then "ip daddr ${cfg.nat.externalIP}" else "iifname ${cfg.nat.externalInterface}"} jump port-forward"}
+                iifForward
+                "iifname ${cfg.nat.externalInterface} jump iif-port-forward"}
+              ${optionalString
+                dipForward
+                (concatMapStringsSep "\n    " (ip: "${ipK ip} daddr ${ip} jump fwd-${ipEscaped ip}") (attrNames cfg.nat.forwardPorts))}
             }
           }
         '';
-    })
+    }))
   ]);
 
   meta.buildDocsInSandbox = false;