Commit 686a8a4f authored by Christof Schulze's avatar Christof Schulze 😎
Browse files

added path.split_path

parent 452f7cf9
Loading
Loading
Loading
Loading
+26 −0
Original line number Diff line number Diff line
@@ -134,3 +134,29 @@ def cleanup_tmp_file(path, warn=False):
                    display.display('Unable to remove temporary file {0}'.format(to_text(e)))
    except Exception:
        pass


def split_path(path: str) -> list:
    """
    Returns all elements of a path separated in its elements.
    :param path:
    :return: list of elements of a path.
    """
    split_path.parts = []

    def _split(path: str):
        rough_split = os.path.split(path)
        cleared_split = list(filter(None, rough_split))

        if len(rough_split) != len(cleared_split):
            if len(cleared_split[0]) == len(path):
                split_path.parts.append(cleared_split[0])
            else:
                _split(cleared_split[0])
        else:
            split_path.parts.append(cleared_split[1])
            _split(cleared_split[0])

    _split(path)

    return list(reversed(split_path.parts))
+26 −0
Original line number Diff line number Diff line
import unittest
from ammsml.utils.path import split_path


class SimpleSplitTest(unittest.TestCase):
    def test_single(self):
        test_list1 = ['/tmp/test']
        for i in test_list1:
            self.assertEqual(split_path(i), ['/', 'tmp', 'test'])

    def test_single_end(self):
        test_list2 = ['/tmp/test/']
        for i in test_list2:
            self.assertEqual(split_path(i), ['/', 'tmp', 'test'])

    def test_single_end(self):
        test_list3 = ['/tmp/test', 'raw', 'test2', 'test3/test31', 'test3/test32/test321']
        self.assertEqual(split_path(test_list3[0]), ['/', 'tmp', 'test'])
        self.assertEqual(split_path(test_list3[1]), ['raw'])
        self.assertEqual(split_path(test_list3[2]), ['test2'])
        self.assertEqual(split_path(test_list3[3]), ['test3', 'test31'])
        self.assertEqual(split_path(test_list3[4]), ['test3', 'test32', 'test321'])


if __name__ == '__main__':
    unittest.main()