1 #include <linux/bpf.h>
2 #include <linux/in.h>
3 #include <linux/if_ether.h>
4 #include <linux/if_packet.h>
5 #include <linux/if_vlan.h>
6 #include <linux/ip.h>
7 #include <bpf/bpf_helpers.h>
8 #include "l3lb.h"
9 #include <uapi/linux/tcp.h>
10
11 struct {
12 __uint(type, BPF_MAP_TYPE_HASH);
13 __uint(key_size, sizeof(unsigned int));
14 __uint(value_size, sizeof(BNode));
15 __uint(max_entries, 4096);
16 __uint(pinning, LIBBPF_PIN_BY_NAME);
17 } l3bindings SEC(".maps");
18
19 struct {
20 __uint(type, BPF_MAP_TYPE_ARRAY);
21 __uint(key_size, sizeof(int));
22 __uint(value_size, sizeof(int));
23 __uint(max_entries, 1);
24 } tx_addr SEC(".maps");
25
26
27 static __always_inline
swap_src_dst_mac(void * data)28 void swap_src_dst_mac(void *data)
29 {
30 unsigned short *p = data;
31 unsigned short dst[3];
32
33 dst[0] = p[0];
34 dst[1] = p[1];
35 dst[2] = p[2];
36 p[0] = p[3];
37 p[1] = p[4];
38 p[2] = p[5];
39 p[3] = dst[0];
40 p[4] = dst[1];
41 p[5] = dst[2];
42 }
43
_check(unsigned int ip)44 static __always_inline BNode *_check(unsigned int ip) {
45 unsigned int m = 0xffffffff;
46 int i;
47 BNode *r;
48 ip = (ip>>24) | (((ip>>16)&0xff)<<8) | (((ip>>8)&0xff)<<16) | ((ip&0xff)<<24);
49 for (i=0; i<31; i++) {
50 ip &= m;
51 r=bpf_map_lookup_elem(&l3bindings, &ip);
52 if (r) return r;
53 m<<=1;
54 }
55 return NULL;
56 }
57
/*
 * iphdr_adjust_csum - incrementally update the IPv4 header checksum.
 * @iph:  header whose check field is rewritten in place.
 * @dsum: signed change in the plain sum of the header's 16-bit words
 *        (new words minus old words, as loaded by the CPU).  It may exceed
 *        16 bits in magnitude, e.g. when both halves of an address change.
 *
 * All arithmetic is one's-complement (mod 0xffff) with end-around carry:
 * the delta is first reduced to a non-negative 16-bit value, added to the
 * current sum, and the carries are folded back until the result fits in
 * 16 bits.  A zero delta leaves the header untouched.
 */
static __always_inline void iphdr_adjust_csum(struct iphdr *iph, int dsum)
{
	unsigned int sum;
	unsigned int delta;

	if (dsum == 0)
		return;

	/* ~check is the one's-complement sum of the remaining header words. */
	sum = ~iph->check & 0xffff;

	if (dsum < 0) {
		/* Magnitude computed in unsigned space so INT_MIN is safe. */
		delta = 0u - (unsigned int)dsum;
		delta = (delta & 0xffff) + (delta >> 16);
		delta = (delta & 0xffff) + (delta >> 16);
		/* Subtracting d equals adding its one's complement. */
		delta = 0xffff - delta;
	} else {
		delta = (unsigned int)dsum;
	}

	sum += delta;
	/* Two folds are enough to absorb every carry of a 32-bit sum. */
	sum = (sum & 0xffff) + (sum >> 16);
	sum = (sum & 0xffff) + (sum >> 16);

	iph->check = ~sum & 0xffff;
}
71
72
73 SEC("xdp")
xdp_ipv4_l3lb(struct xdp_md * ctx)74 int xdp_ipv4_l3lb(struct xdp_md *ctx)
75 {
76 void *data_end = (void *)(long)ctx->data_end;
77 void *data = (void *)(long)ctx->data;
78 struct ethhdr *eth = data;
79 struct iphdr *iph;
80 u16 h_proto;
81 u64 nh_off;
82 BNode *value = NULL;
83 int i, o_ifindex=0, rc;
84 struct bpf_fib_lookup fib_params;
85 unsigned int *pip = NULL;
86 int dsum=0;
87 // struct tcphdr *tcph = NULL;
88
89 nh_off = sizeof(*eth);
90 if (data + nh_off > data_end) return XDP_PASS;
91
92 h_proto = eth->h_proto;
93
94 if (h_proto == htons(ETH_P_8021Q) || h_proto == htons(ETH_P_8021AD)) {
95 struct vlan_hdr *vhdr;
96
97 vhdr = data + nh_off;
98 nh_off += sizeof(struct vlan_hdr);
99 if (data + nh_off > data_end)
100 return XDP_PASS;
101 h_proto = vhdr->h_vlan_encapsulated_proto;
102 }
103 if (h_proto == htons(ETH_P_8021Q) || h_proto == htons(ETH_P_8021AD)) {
104 struct vlan_hdr *vhdr;
105
106 vhdr = data + nh_off;
107 nh_off += sizeof(struct vlan_hdr);
108 if (data + nh_off > data_end)
109 return XDP_PASS;
110 h_proto = vhdr->h_vlan_encapsulated_proto;
111 }
112 if (h_proto != htons(ETH_P_IP)) return XDP_PASS;
113 if (data + nh_off + sizeof(*iph) > data_end) return XDP_PASS;
114 iph = data + nh_off;
115 if (iph->ttl <= 1) return XDP_PASS;
116
117 i=0;
118 pip = bpf_map_lookup_elem(&tx_addr, &i);
119 if (pip == NULL) return XDP_PASS;
120
121 __builtin_memset(&fib_params, 0, sizeof(fib_params));
122 fib_params.family = AF_INET;
123 fib_params.tos = iph->tos;
124 // fib_params.l4_protocol = iph->protocol;
125 fib_params.tot_len = ntohs(iph->tot_len);
126
127 if (iph->daddr == *pip) {
128 value = _check(iph->saddr);
129 if (!value) return XDP_PASS;
130 if (ctx->ingress_ifindex != value->ifin) return XDP_PASS;
131 // change daddr and redirect
132 // bpf_printk("capture source ip %x -> %x %x\n", iph->saddr, iph->daddr, value->daddr);
133 fib_params.ipv4_src = *pip; // iph->saddr;
134 fib_params.ipv4_dst = value->daddr;
135 o_ifindex = value->ifout;
136 dsum = (value->daddr&0xffff) + (value->daddr>>16);
137 dsum -= (iph->daddr&0xffff) + (iph->daddr>>16);
138 iph->daddr = value->daddr;
139 } else {
140 value = _check(iph->daddr);
141 if (!value) return XDP_PASS;
142 if (ctx->ingress_ifindex != value->ifout) return XDP_PASS;
143 // change saddr
144 // bpf_printk("capture dest ip %x <> %x %x\n", iph->saddr, iph->daddr, value->saddr);
145 fib_params.ipv4_src = value->saddr;
146 fib_params.ipv4_dst = iph->daddr;
147 o_ifindex = value->ifin;
148 dsum = (value->saddr&0xffff) + (value->saddr>>16);
149 dsum -= (iph->saddr&0xffff) + (iph->saddr>>16);
150 iph->saddr = value->saddr;
151 }
152 iph->ttl--;
153 dsum -= htons(0x0100);
154 iphdr_adjust_csum(iph, dsum);
155 // iph->protocol == IPPROTO_TCP tcp->check ?
156
157 fib_params.ifindex = ctx->ingress_ifindex; // o_ifindex;
158 rc = bpf_fib_lookup(ctx, &fib_params, sizeof(fib_params), BPF_FIB_LOOKUP_DIRECT); // 0, BPF_FIB_LOOKUP_DIRECT, or BPF_FIB_LOOKUP_OUTPUT
159 if (rc == BPF_FIB_LKUP_RET_SUCCESS) {
160 memcpy(eth->h_dest, fib_params.dmac, ETH_ALEN);
161 memcpy(eth->h_source, fib_params.smac, ETH_ALEN);
162 return bpf_redirect(o_ifindex, 0);
163 } else {
164 bpf_printk("fib failed: %d, try sending icmp ping?\n", rc);
165 }
166 return XDP_PASS;
167 }
168
169
170 char LICENSE[] SEC("license") = "Dual BSD/GPL";
171
172