1 #include <linux/bpf.h> 2 #include <linux/in.h> 3 #include <linux/if_ether.h> 4 #include <linux/if_packet.h> 5 #include <linux/if_vlan.h> 6 #include <linux/ip.h> 7 #include <bpf/bpf_helpers.h> 8 #include "l3lb.h" 9 #include <uapi/linux/tcp.h> 10 11 struct { 12 __uint(type, BPF_MAP_TYPE_HASH); 13 __uint(key_size, sizeof(unsigned int)); 14 __uint(value_size, sizeof(BNode)); 15 __uint(max_entries, 4096); 16 __uint(pinning, LIBBPF_PIN_BY_NAME); 17 } l3bindings SEC(".maps"); 18 19 struct { 20 __uint(type, BPF_MAP_TYPE_ARRAY); 21 __uint(key_size, sizeof(int)); 22 __uint(value_size, sizeof(int)); 23 __uint(max_entries, 1); 24 } tx_addr SEC(".maps"); 25 26 27 static __always_inline 28 void swap_src_dst_mac(void *data) 29 { 30 unsigned short *p = data; 31 unsigned short dst[3]; 32 33 dst[0] = p[0]; 34 dst[1] = p[1]; 35 dst[2] = p[2]; 36 p[0] = p[3]; 37 p[1] = p[4]; 38 p[2] = p[5]; 39 p[3] = dst[0]; 40 p[4] = dst[1]; 41 p[5] = dst[2]; 42 } 43 44 static __always_inline BNode *_check(unsigned int ip) { 45 unsigned int m = 0xffffffff; 46 int i; 47 BNode *r; 48 ip = (ip>>24) | (((ip>>16)&0xff)<<8) | (((ip>>8)&0xff)<<16) | ((ip&0xff)<<24); 49 for (i=0; i<31; i++) { 50 ip &= m; 51 r=bpf_map_lookup_elem(&l3bindings, &ip); 52 if (r) return r; 53 m<<=1; 54 } 55 return NULL; 56 } 57 58 static __always_inline void iphdr_adjust_csum(struct iphdr* iph, int dsum) { 59 if (dsum == 0) return; 60 unsigned int csum = (~iph->check)&0xffff; 61 if (dsum < 0) { 62 dsum = -dsum; 63 if (csum >= dsum) csum-=dsum; 64 else csum += (0xffff-dsum); 65 } else if (dsum > 0) { 66 csum+=dsum; 67 if (csum>0xffff) csum = (csum&0xffff)+1; 68 } 69 iph->check = ~csum; 70 } 71 72 73 SEC("xdp") 74 int xdp_ipv4_l3lb(struct xdp_md *ctx) 75 { 76 void *data_end = (void *)(long)ctx->data_end; 77 void *data = (void *)(long)ctx->data; 78 struct ethhdr *eth = data; 79 struct iphdr *iph; 80 u16 h_proto; 81 u64 nh_off; 82 BNode *value = NULL; 83 int i, o_ifindex=0, rc; 84 struct bpf_fib_lookup fib_params; 85 unsigned int *pip = NULL; 86 int dsum=0; 87 // struct tcphdr *tcph = NULL; 88 89 nh_off = sizeof(*eth); 90 if (data + nh_off > data_end) return XDP_PASS; 91 92 h_proto = eth->h_proto; 93 94 if (h_proto == htons(ETH_P_8021Q) || h_proto == htons(ETH_P_8021AD)) { 95 struct vlan_hdr *vhdr; 96 97 vhdr = data + nh_off; 98 nh_off += sizeof(struct vlan_hdr); 99 if (data + nh_off > data_end) 100 return XDP_PASS; 101 h_proto = vhdr->h_vlan_encapsulated_proto; 102 } 103 if (h_proto == htons(ETH_P_8021Q) || h_proto == htons(ETH_P_8021AD)) { 104 struct vlan_hdr *vhdr; 105 106 vhdr = data + nh_off; 107 nh_off += sizeof(struct vlan_hdr); 108 if (data + nh_off > data_end) 109 return XDP_PASS; 110 h_proto = vhdr->h_vlan_encapsulated_proto; 111 } 112 if (h_proto != htons(ETH_P_IP)) return XDP_PASS; 113 if (data + nh_off + sizeof(*iph) > data_end) return XDP_PASS; 114 iph = data + nh_off; 115 if (iph->ttl <= 1) return XDP_PASS; 116 117 i=0; 118 pip = bpf_map_lookup_elem(&tx_addr, &i); 119 if (pip == NULL) return XDP_PASS; 120 121 __builtin_memset(&fib_params, 0, sizeof(fib_params)); 122 fib_params.family = AF_INET; 123 fib_params.tos = iph->tos; 124 // fib_params.l4_protocol = iph->protocol; 125 fib_params.tot_len = ntohs(iph->tot_len); 126 127 if (iph->daddr == *pip) { 128 value = _check(iph->saddr); 129 if (!value) return XDP_PASS; 130 if (ctx->ingress_ifindex != value->ifin) return XDP_PASS; 131 // change daddr and redirect 132 // bpf_printk("capture source ip %x -> %x %x\n", iph->saddr, iph->daddr, value->daddr); 133 fib_params.ipv4_src = *pip; // iph->saddr; 134 fib_params.ipv4_dst = value->daddr; 135 o_ifindex = value->ifout; 136 dsum = (value->daddr&0xffff) + (value->daddr>>16); 137 dsum -= (iph->daddr&0xffff) + (iph->daddr>>16); 138 iph->daddr = value->daddr; 139 } else { 140 value = _check(iph->daddr); 141 if (!value) return XDP_PASS; 142 if (ctx->ingress_ifindex != value->ifout) return XDP_PASS; 143 // change saddr 144 // bpf_printk("capture dest ip %x <> %x %x\n", iph->saddr, iph->daddr, value->saddr); 145 fib_params.ipv4_src = value->saddr; 146 fib_params.ipv4_dst = iph->daddr; 147 o_ifindex = value->ifin; 148 dsum = (value->saddr&0xffff) + (value->saddr>>16); 149 dsum -= (iph->saddr&0xffff) + (iph->saddr>>16); 150 iph->saddr = value->saddr; 151 } 152 iph->ttl--; 153 dsum -= htons(0x0100); 154 iphdr_adjust_csum(iph, dsum); 155 // iph->protocol == IPPROTO_TCP tcp->check ? 156 157 fib_params.ifindex = ctx->ingress_ifindex; // o_ifindex; 158 rc = bpf_fib_lookup(ctx, &fib_params, sizeof(fib_params), BPF_FIB_LOOKUP_DIRECT); // 0, BPF_FIB_LOOKUP_DIRECT, or BPF_FIB_LOOKUP_OUTPUT 159 if (rc == BPF_FIB_LKUP_RET_SUCCESS) { 160 memcpy(eth->h_dest, fib_params.dmac, ETH_ALEN); 161 memcpy(eth->h_source, fib_params.smac, ETH_ALEN); 162 return bpf_redirect(o_ifindex, 0); 163 } else { 164 bpf_printk("fib failed: %d, try sending icmp ping?\n", rc); 165 } 166 return XDP_PASS; 167 } 168 169 170 char LICENSE[] SEC("license") = "Dual BSD/GPL"; 171 172