xref: /linux-tools/ebpf/kernel-samples-bpf/l3lb_kern.c (revision 7889a9fe3f6cd23238c94fad4e1c698d5585c3fe)
1 #include <linux/bpf.h>
2 #include <linux/in.h>
3 #include <linux/if_ether.h>
4 #include <linux/if_packet.h>
5 #include <linux/if_vlan.h>
6 #include <linux/ip.h>
7 #include <bpf/bpf_helpers.h>
8 #include "l3lb.h"
9 #include <uapi/linux/tcp.h>
10 
11 struct {
12 	__uint(type, BPF_MAP_TYPE_HASH);
13  	__uint(key_size, sizeof(unsigned int));
14 	__uint(value_size, sizeof(BNode));
15 	__uint(max_entries, 4096);
16     __uint(pinning, LIBBPF_PIN_BY_NAME);
17 } l3bindings SEC(".maps");
18 
19 struct {
20 	__uint(type, BPF_MAP_TYPE_ARRAY);
21 	__uint(key_size, sizeof(int));
22 	__uint(value_size, sizeof(int));
23 	__uint(max_entries, 1);
24 } tx_addr SEC(".maps");
25 
26 
/*
 * Exchange the first three 16-bit words at @data with the following three,
 * i.e. swap two adjacent 6-byte fields in place (per the name: the
 * destination and source MAC addresses at the start of an Ethernet header).
 *
 * @data: start of the 12-byte region to modify. It is accessed through an
 *        unsigned short pointer, so the caller must make sure all 12 bytes
 *        are readable and writable (presumably 2-byte aligned as well).
 */
static __always_inline
void swap_src_dst_mac(void *data)
{
	unsigned short *words = data;
	int idx;

	for (idx = 0; idx < 3; idx++) {
		unsigned short saved = words[idx];

		words[idx] = words[idx + 3];
		words[idx + 3] = saved;
	}
}
43 
_check(unsigned int ip)44 static __always_inline BNode *_check(unsigned int ip) {
45     unsigned int m = 0xffffffff;
46     int i;
47     BNode *r;
48     ip = (ip>>24) | (((ip>>16)&0xff)<<8) | (((ip>>8)&0xff)<<16) | ((ip&0xff)<<24);
49     for (i=0; i<31; i++) {
50         ip &= m;
51         r=bpf_map_lookup_elem(&l3bindings, &ip);
52         if (r) return r;
53         m<<=1;
54     }
55     return NULL;
56 }
57 
/*
 * Incrementally update the IPv4 header checksum after header fields changed.
 *
 * @iph:  header whose ->check field is rewritten in place.
 * @dsum: signed difference "sum of the new 16-bit words minus sum of the old
 *        16-bit words", with the words read the same way ->check is read.
 *        Any int value is accepted; 0 leaves the header untouched.
 *
 * The update follows RFC 1624 (eqn. 3): the one's-complement sum is recovered
 * from ->check, the delta is added in one's-complement arithmetic (a negative
 * delta is added as the complement of its magnitude), and the result is
 * complemented again. Every intermediate value is folded back to 16 bits, so
 * deltas larger than 0xffff in magnitude (possible when a whole 32-bit
 * address is replaced) and multi-bit carries are handled correctly.
 */
static __always_inline void iphdr_adjust_csum(struct iphdr *iph, int dsum)
{
	unsigned int delta;
	unsigned int csum;

	if (dsum == 0)
		return;

	/* Magnitude of the delta, computed unsigned so no signed negation occurs. */
	delta = dsum < 0 ? 0u - (unsigned int)dsum : (unsigned int)dsum;

	/* Two folds bring any 32-bit value down to 16 bits with end-around carry. */
	delta = (delta & 0xffff) + (delta >> 16);
	delta = (delta & 0xffff) + (delta >> 16);

	/* Subtracting x in one's-complement arithmetic is adding ~x. */
	if (dsum < 0)
		delta = ~delta & 0xffff;

	csum = (~(unsigned int)iph->check & 0xffff) + delta;
	/* Both operands are <= 0xffff, so a single fold is enough here. */
	csum = (csum & 0xffff) + (csum >> 16);

	iph->check = ~csum;
}
71 
72 
73 SEC("xdp")
xdp_ipv4_l3lb(struct xdp_md * ctx)74 int xdp_ipv4_l3lb(struct xdp_md *ctx)
75 {
76 	void *data_end = (void *)(long)ctx->data_end;
77 	void *data = (void *)(long)ctx->data;
78 	struct ethhdr *eth = data;
79     struct iphdr *iph;
80 	u16 h_proto;
81 	u64 nh_off;
82     BNode *value = NULL;
83     int i, o_ifindex=0, rc;
84 	struct bpf_fib_lookup fib_params;
85     unsigned int *pip = NULL;
86     int dsum=0;
87     // struct tcphdr *tcph = NULL;
88 
89 	nh_off = sizeof(*eth);
90 	if (data + nh_off > data_end) return XDP_PASS;
91 
92 	h_proto = eth->h_proto;
93 
94 	if (h_proto == htons(ETH_P_8021Q) || h_proto == htons(ETH_P_8021AD)) {
95 		struct vlan_hdr *vhdr;
96 
97 		vhdr = data + nh_off;
98 		nh_off += sizeof(struct vlan_hdr);
99 		if (data + nh_off > data_end)
100 			return XDP_PASS;
101 		h_proto = vhdr->h_vlan_encapsulated_proto;
102 	}
103 	if (h_proto == htons(ETH_P_8021Q) || h_proto == htons(ETH_P_8021AD)) {
104 		struct vlan_hdr *vhdr;
105 
106 		vhdr = data + nh_off;
107 		nh_off += sizeof(struct vlan_hdr);
108 		if (data + nh_off > data_end)
109 			return XDP_PASS;
110 		h_proto = vhdr->h_vlan_encapsulated_proto;
111 	}
112 	if (h_proto != htons(ETH_P_IP)) return XDP_PASS;
113     if (data + nh_off + sizeof(*iph) > data_end) return XDP_PASS;
114 	iph = data + nh_off;
115     if (iph->ttl <= 1) return XDP_PASS;
116 
117     i=0;
118     pip = bpf_map_lookup_elem(&tx_addr, &i);
119     if (pip == NULL) return XDP_PASS;
120 
121 	__builtin_memset(&fib_params, 0, sizeof(fib_params));
122     fib_params.family	= AF_INET;
123     fib_params.tos		= iph->tos;
124     // fib_params.l4_protocol	= iph->protocol;
125     fib_params.tot_len	= ntohs(iph->tot_len);
126 
127     if (iph->daddr == *pip) {
128         value =  _check(iph->saddr);
129         if (!value) return XDP_PASS;
130         if (ctx->ingress_ifindex != value->ifin) return XDP_PASS;
131         // change daddr and  redirect
132         // bpf_printk("capture source ip %x -> %x %x\n", iph->saddr, iph->daddr, value->daddr);
133         fib_params.ipv4_src	= *pip; // iph->saddr;
134 		fib_params.ipv4_dst	= value->daddr;
135         o_ifindex = value->ifout;
136         dsum = (value->daddr&0xffff) + (value->daddr>>16);
137         dsum -= (iph->daddr&0xffff) + (iph->daddr>>16);
138         iph->daddr = value->daddr;
139     } else {
140         value = _check(iph->daddr);
141         if (!value) return XDP_PASS;
142         if (ctx->ingress_ifindex != value->ifout) return XDP_PASS;
143         // change saddr
144         // bpf_printk("capture dest ip %x <> %x %x\n", iph->saddr, iph->daddr, value->saddr);
145         fib_params.ipv4_src	= value->saddr;
146 		fib_params.ipv4_dst	= iph->daddr;
147         o_ifindex = value->ifin;
148         dsum = (value->saddr&0xffff) + (value->saddr>>16);
149         dsum -= (iph->saddr&0xffff) + (iph->saddr>>16);
150         iph->saddr = value->saddr;
151     }
152     iph->ttl--;
153     dsum -= htons(0x0100);
154     iphdr_adjust_csum(iph, dsum);
155     // iph->protocol == IPPROTO_TCP tcp->check ?
156 
157 	fib_params.ifindex = ctx->ingress_ifindex; // o_ifindex;
158 	rc = bpf_fib_lookup(ctx, &fib_params, sizeof(fib_params), BPF_FIB_LOOKUP_DIRECT); // 0, BPF_FIB_LOOKUP_DIRECT, or BPF_FIB_LOOKUP_OUTPUT
159 	if (rc == BPF_FIB_LKUP_RET_SUCCESS) {
160 		memcpy(eth->h_dest, fib_params.dmac, ETH_ALEN);
161 		memcpy(eth->h_source, fib_params.smac, ETH_ALEN);
162 		return  bpf_redirect(o_ifindex, 0);
163 	} else {
164         bpf_printk("fib failed: %d, try sending icmp ping?\n", rc);
165     }
166 	return XDP_PASS;
167 }
168 
169 
170 char LICENSE[] SEC("license") = "Dual BSD/GPL";
171 
172