15 """Starts a local DNS server for use in tests"""
26 import twisted.internet
27 import twisted.internet.defer
28 import twisted.internet.protocol
29 import twisted.internet.reactor
30 import twisted.internet.threads
32 from twisted.names
import authority
33 from twisted.names
import client
34 from twisted.names
import common
35 from twisted.names
import dns
36 from twisted.names
import server
37 import twisted.names.client
38 import twisted.names.dns
39 import twisted.names.server
# The server registers an A record mapping _SERVER_HEALTH_CHECK_RECORD_NAME to
# _SERVER_HEALTH_CHECK_RECORD_DATA (with TTL 0) in addition to the records
# loaded from the test records config.
# NOTE(review): presumably test harnesses resolve this name to check that the
# local DNS server is up and answering -- confirm against callers.
_SERVER_HEALTH_CHECK_RECORD_NAME = (
    'health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp')
_SERVER_HEALTH_CHECK_RECORD_DATA = '123.123.123.123'
50 common.ResolverBase.__init__(self)
58 def _push_record(name, r):
59 name = name.encode(
'ascii')
60 print(
'pushing record: |%s|' % name)
61 if all_records.get(name)
is not None:
62 all_records[name].append(r)
64 all_records[name] = [r]
66 def _maybe_split_up_txt_data(name, txt_data, r_ttl):
67 txt_data = txt_data.encode(
'ascii')
70 while len(txt_data[start:]) > 0:
71 next_read =
len(txt_data[start:])
74 txt_data_list.append(txt_data[start:start + next_read])
76 _push_record(name, dns.Record_TXT(*txt_data_list, ttl=r_ttl))
78 with open(args.records_config_path)
as config:
79 test_records_config = yaml.load(config)
80 common_zone_name = test_records_config[
'resolver_tests_common_zone_name']
81 for group
in test_records_config[
'resolver_component_tests']:
82 for name
in group[
'records'].
keys():
83 for record
in group[
'records'][name]:
84 r_type = record[
'type']
85 r_data = record[
'data']
86 r_ttl =
int(record[
'TTL'])
87 record_full_name =
'%s.%s' % (name, common_zone_name)
88 assert record_full_name[-1] ==
'.'
89 record_full_name = record_full_name[:-1]
91 _push_record(record_full_name,
92 dns.Record_A(r_data, ttl=r_ttl))
94 _push_record(record_full_name,
95 dns.Record_AAAA(r_data, ttl=r_ttl))
97 p, w, port, target = r_data.split(
' ')
102 '%s.%s' % (target, common_zone_name)).
encode(
'ascii')
105 dns.Record_SRV(p, w, port, target_full_name, ttl=r_ttl))
107 _maybe_split_up_txt_data(record_full_name, r_data, r_ttl)
109 if args.add_a_record:
110 extra_host, extra_host_ipv4 = args.add_a_record.split(
':')
111 _push_record(extra_host, dns.Record_A(extra_host_ipv4, ttl=0))
113 _push_record(_SERVER_HEALTH_CHECK_RECORD_NAME,
114 dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0))
115 soa_record = dns.Record_SOA(mname=common_zone_name.encode(
'ascii'))
117 soa=(common_zone_name.encode(
'ascii'), soa_record),
120 server = twisted.names.server.DNSServerFactory(
121 authorities=[test_domain_com], verbose=2)
123 twisted.internet.reactor.listenTCP(args.port, server)
124 dns_proto = twisted.names.dns.DNSDatagramProtocol(server)
126 twisted.internet.reactor.listenUDP(args.port, dns_proto)
127 print(
'starting local dns server on 127.0.0.1:%s' % args.port)
128 print(
'starting twisted.internet.reactor')
129 twisted.internet.reactor.suggestThreadPoolSize(1)
130 twisted.internet.reactor.run()
134 print(
'Received SIGNAL %d. Quitting with exit code 0' % signum)
135 twisted.internet.reactor.stop()
141 num_timeouts_so_far = 0
144 max_timeouts = 60 * 10
145 while num_timeouts_so_far < max_timeouts:
147 time.sleep(sleep_time)
148 num_timeouts_so_far += 1
149 print(
'Process timeout reached, or cancelled. Exitting 0.')
150 os.kill(os.getpid(), signal.SIGTERM)
154 argp = argparse.ArgumentParser(
155 description=
'Local DNS Server for resolver tests')
156 argp.add_argument(
'-p',
160 help=
'Port for DNS server to listen on for TCP and UDP.')
163 '--records_config_path',
166 help=(
'Directory of resolver_test_record_groups.yaml file. '
167 'Defaults to path needed when the test is invoked as part '
173 help=(
'Add an A record via the command line. Useful for when we '
174 'need to serve a one-off A record that is under a '
175 'different domain then the rest the records configured in '
176 '--records_config_path (which all need to be under the '
177 'same domain). Format: <name>:<ipv4 address>'))
178 args = argp.parse_args()
179 signal.signal(signal.SIGTERM, _quit_on_signal)
180 signal.signal(signal.SIGINT, _quit_on_signal)
181 output_flush_thread = threading.Thread(target=flush_stdout_loop)
182 output_flush_thread.setDaemon(
True)
183 output_flush_thread.start()
187 if __name__ ==
'__main__':