diff --git a/testenv/Makefile.am b/testenv/Makefile.am index 39ae3497..8f1b5f4a 100644 --- a/testenv/Makefile.am +++ b/testenv/Makefile.am @@ -56,6 +56,9 @@ if HAVE_PYTHON3 Test-cookie-expires.py \ Test-cookie.py \ Test-Head.py \ + Test--https.py \ + Test--https-crl.py \ + Test-hsts.py \ Test-O.py \ Test-Post.py \ Test-504.py \ diff --git a/testenv/Test-hsts.py b/testenv/Test-hsts.py new file mode 100755 index 00000000..42909294 --- /dev/null +++ b/testenv/Test-hsts.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +from sys import exit +from test.http_test import HTTPTest +from test.base_test import HTTP, HTTPS +from misc.wget_file import WgetFile +import time +import os + +""" +This test makes sure Wget can parse a given HSTS database and apply the indicated HSTS policy. +""" +def hsts_database_path(): + hsts_file = ".wget-hsts-testenv" + return os.path.abspath(hsts_file) + +def create_hsts_database(path, host, port): + # we want the current time as an integer, + # not as a floating point + curtime = int(time.time()) + max_age = "123456" + + f = open(path, "w") + + f.write("# dummy comment\n") + f.write(host + "\t" + str(port) + "\t0\t" + str(curtime) + "\t" + max_age + "\n") + f.close() + +TEST_NAME = "HSTS basic test" + +File_Name = "hw" +File_Content = "Hello, world!" +File = WgetFile(File_Name, File_Content) + +Hsts_File_Path = hsts_database_path() + +CAFILE = os.path.abspath(os.path.join(os.getenv('srcdir', '.'), 'certs', 'ca-cert.pem')) + +WGET_OPTIONS = "--hsts-file=" + Hsts_File_Path + " --ca-certificate=" + CAFILE +WGET_URLS = [[File_Name]] + +Files = [[File]] +Servers = [HTTPS] +Requests = ["http"] + +ExpectedReturnCode = 0 +ExpectedDownloadedFiles = [File] + +pre_test = { + "ServerFiles" : Files, + "Domains" : ["localhost"] +} +post_test = { + "ExpectedFiles" : ExpectedDownloadedFiles, + "ExpectedRetCode" : ExpectedReturnCode, +} +test_options = { + "WgetCommands" : WGET_OPTIONS, + "Urls" : WGET_URLS +} + +test = HTTPTest( + name = TEST_NAME, + pre_hook = pre_test, + post_hook = post_test, + test_params = test_options, + protocols = Servers, + req_protocols = Requests +) + +# start the web server and create the temporary HSTS database +test.setup() +create_hsts_database(Hsts_File_Path, 'localhost', test.port) + +err = test.begin() + +# remove the temporary HSTS database +os.unlink(hsts_database_path()) +exit(err) diff --git a/testenv/conf/domains.py b/testenv/conf/domains.py new file mode 100644 index 00000000..ac03fe19 --- /dev/null +++ b/testenv/conf/domains.py @@ -0,0 +1,9 @@ +from conf import hook + +@hook(alias='Domains') +class Domains: + def __init__(self, domains): + self.domains = domains + + def __call__(self, test_obj): + test_obj.domains = self.domains diff --git a/testenv/test/base_test.py b/testenv/test/base_test.py index 3b989100..c5b82bea 100644 --- a/testenv/test/base_test.py +++ b/testenv/test/base_test.py @@ -22,7 +22,7 @@ class BaseTest: * instantiate_server_by(protocol) """ - def __init__(self, name, pre_hook, test_params, post_hook, protocols): + def __init__(self, name, pre_hook, test_params, post_hook, protocols, req_protocols): """ Define the class-wide variables (or attributes). Attributes should not be defined outside __init__. @@ -36,14 +36,23 @@ class BaseTest: self.post_configs = post_hook or {} self.protocols = protocols + if req_protocols is None: + self.req_protocols = map(lambda p: p.lower(), self.protocols) + else: + self.req_protocols = req_protocols + self.servers = [] self.domains = [] + self.ports = [] + + self.addr = None self.port = -1 self.wget_options = '' self.urls = [] self.tests_passed = True + self.ready = False self.init_test_env() self.ret_code = 0 @@ -63,9 +72,12 @@ class BaseTest: def get_domain_addr(self, addr): # TODO if there's a multiple number of ports, wouldn't it be # overridden to the port of the last invocation? + # Set the instance variables 'addr' and 'port' so that + # they can be queried by test cases. + self.addr = str(addr[0]) self.port = str(addr[1]) - return '%s:%s' % (addr[0], self.port) + return [self.addr, self.port] def server_setup(self): print_blue("Running Test %s" % self.name) @@ -77,7 +89,8 @@ class BaseTest: # ports and etc. # so we should record different domains respect to servers. domain = self.get_domain_addr(instance.server_address) - self.domains.append(domain) + self.domains.append(domain[0]) + self.ports.append(domain[1]) def exec_wget(self): cmd_line = self.gen_cmd_line() @@ -122,9 +135,10 @@ class BaseTest: else: cmd_line = '%s %s ' % (wget_path, wget_options) - for protocol, urls, domain in zip(self.protocols, - self.urls, - self.domains): + for req_protocol, urls, domain, port in zip(self.req_protocols, + self.urls, + self.domains, + self.ports): # zip is function for iterating multiple lists at the same time. # e.g. for item1, item2 in zip([1, 5, 3], # ['a', 'e', 'c']): @@ -134,7 +148,8 @@ class BaseTest: # 5 e # 3 c for url in urls: - cmd_line += '%s://%s/%s ' % (protocol.lower(), domain, url) + cmd_line += '%s://%s:%s/%s ' % (req_protocol, domain, port, url) + print(cmd_line) diff --git a/testenv/test/http_test.py b/testenv/test/http_test.py index 32a3335a..5a13d4f6 100644 --- a/testenv/test/http_test.py +++ b/testenv/test/http_test.py @@ -17,19 +17,34 @@ class HTTPTest(BaseTest): pre_hook=None, test_params=None, post_hook=None, - protocols=(HTTP,)): + protocols=(HTTP,), + req_protocols=None): super(HTTPTest, self).__init__(name, pre_hook, test_params, post_hook, - protocols) + protocols, + req_protocols) + + def setup(self): self.server_setup() + self.ready = True def begin(self): + if not self.ready: + # this is to maintain compatibility with scripts that + # don't call setup() + self.setup() with self: - self.do_test() - print_green('Test Passed.') - return super(HTTPTest, self).begin() + # If any exception occurs, self.__exit__ will be immediately called. + # We must call the parent method in the end in order to verify + # whether the tests succeeded or not. + if self.ready: + self.do_test() + print_green("Test Passed.") + else: + self.tests_passed = False + super(HTTPTest, self).begin() def instantiate_server_by(self, protocol): server = {HTTP: HTTPd,